Authors: Daniel Meier (Swiss Re Ltd), Michael Mayer (La Mobilière), members of the Data Science Working Group of the Swiss Association of Actuaries, see https://actuarialdatascience.org, and Juan-Ramón Troncoso-Pastoriza (Tune Insight)
This notebook introduces privacy-preserving machine learning methods using synthetic health datasets with risk factors such as BMI, blood pressure, and age to predict various health outcomes of individuals over time. The notebook consists of 5 parts:
Importing all packages used throughout the notebook.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn
import plotly.graph_objects as go
import statsmodels.formula.api as sm
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Input, Dense, Activation
from sklearn.metrics import auc, roc_auc_score #, log_loss
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve
from scipy.special import logit
import random
import sys
if 'lifelines' not in sys.modules:
    %pip install lifelines
import lifelines as ll
if 'dalex' not in sys.modules:
    %pip install dalex
import dalex as dx
Successfully installed astor-0.8.1 autograd-gamma-0.5.0 formulaic-0.6.4 interface-meta-1.3.0 lifelines-0.27.8
Successfully installed dalex-1.6.0
We start by defining the relative risk (RR) of an individual with a given vector/tuple of attributes ${\bf x}$, e.g., body mass index (BMI), systolic blood pressure (SBP), etc.
$$ \mathrm{RR}({\bf x}) = \frac{\mu({\bf x})}{\mu_{\mathrm{ref}}} , $$where $\mu({\bf x})$ is the probability/risk of a health event, e.g., heart attack or mortality, happening over a given fixed time horizon, e.g., one year or 10 years. Relative risks require a reference risk $\mu_{\mathrm{ref}}$, usually the assumed probability/risk of a population or of an individual with reference attributes ${\bf x}_{\mathrm{ref}}$. We use the following attributes as reference and set the time horizon max_T to 10 years.
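As a minimal numerical sketch (with made-up risk values, not QRISK3 output), the relative risk is just the ratio of the two probabilities:

```python
# Relative risk: ratio of an individual's event probability to the reference risk.
def relative_risk(mu_x, mu_ref):
    return mu_x / mu_ref

# Hypothetical example: a 2% ten-year risk against a 1% reference risk.
print(relative_risk(0.02, 0.01))  # → 2.0
```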
max_T = 10
age = 35
gender = "male"
height = 180
weight = 75 # body-mass-index BMI = weight/(height/100)^2 = 23.1 kg/m^2
tcl_hdl_ratio = 3.5 # the ratio of total cholesterol level (TCL) to high density lipoprotein (HDL) cholesterol level (the "good" cholesterol), values below 3.5 are considered very good
sbp = 120 # systolic blood pressure, unit mmHg
sd_sbp = 10 # the standard deviation of SBP measurements
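Since height is given in centimetres, the BMI of the reference person can be sanity-checked as follows (the reference values are repeated so the snippet is self-contained):

```python
# Reference attributes (repeated here for a self-contained check).
height = 180  # cm
weight = 75   # kg

# BMI = weight / height^2, with height converted from cm to m.
bmi = weight / (height / 100) ** 2
print(round(bmi, 1))  # → 23.1
```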
We initially focus only on BMI and SBP (and age) within these ranges:
bmi_range = np.arange(18, 40, 0.5)
sbp_range = np.arange(110, 150, 1)
age_range = np.arange(25, 85, 1)
Mainly due to data privacy reasons, it remains challenging to get public access to large-scale (longitudinal) health datasets. The following data sources are worth mentioning:
As a publicly available starting point to create synthetic datasets, we are using QRISK3, see https://qrisk.org/, a risk calculator to estimate the probability of developing a heart attack or stroke over the next 10 years - we will denote this as CVD risk (cardio-vascular diseases).
We drop several categorical variables of the calculator, e.g., ethnicity, and work only with age, gender, height, weight, TCL/HDL ratio, SBP, and the standard deviation of SBP (the variability across doctor visit-to-visit measurements).
def qrisk3(age, gender, height, weight, tcl_hdl_ratio, sbp, sd_sbp):
    age = np.maximum(age, 25 - 10*np.tanh((25 - age)/10)) # extension to below age 25
    if gender == "female":
        age_1 = np.power(age/10, -2)
        age_2 = age/10
        bmi = weight/height/height*10000
        bmi_1 = np.power(bmi/10,-2)
        bmi_2 = np.power(bmi/10,-2)*np.log(bmi/10)
        age_1 = age_1 - 0.053274843841791
        age_2 = age_2 - 4.332503318786621
        bmi_1 = bmi_1 - 0.154946178197861
        bmi_2 = bmi_2 - 0.144462317228317
        tcl_hdl_ratio = tcl_hdl_ratio - 3.476326465606690
        sbp = sbp - 123.130012512207030
        sd_sbp = sd_sbp - 9.002537727355957
        # additive parts of effects
        a = age_1 * -8.1388109247726188 + age_2 * 0.79733376689699098 + bmi_1 * 0.29236092275460052 + bmi_2 * -4.1513300213837665 + tcl_hdl_ratio * 0.15338035820802554 + sbp * 0.013131488407103424 + sd_sbp * 0.0078894541014586095
        # interaction parts
        a += age_1 * bmi_1 * 23.802623412141742 + age_1 * bmi_2 * -71.184947692087007 + age_1 * sbp * 0.034131842338615485 + age_2 * bmi_1 * 0.52369958933664429 + age_2 * bmi_2 * 0.045744190122323759 + age_2 * sbp * -0.0015082501423272358
        return 1 - np.power(0.988876402378082, np.exp(a))
    else:
        age_1 = np.power(age/10, -1)
        age_2 = np.power(age/10, 3)
        bmi = weight/height/height*10000
        bmi_1 = np.power(bmi/10,-2)
        bmi_2 = np.power(bmi/10,-2)*np.log(bmi/10)
        age_1 = age_1 - 0.234766781330109
        age_2 = age_2 - 77.284080505371094
        bmi_1 = bmi_1 - 0.149176135659218
        bmi_2 = bmi_2 - 0.141913309693336
        tcl_hdl_ratio = tcl_hdl_ratio - 4.300998687744141
        sbp = sbp - 128.571578979492190
        sd_sbp = sd_sbp - 8.756621360778809
        # additive parts of effects
        a = age_1 * -17.839781666005575 + age_2 * 0.0022964880605765492 + bmi_1 * 2.4562776660536358 + bmi_2 * -8.3011122314711354 + tcl_hdl_ratio * 0.17340196856327111 + sbp * 0.012910126542553305 + sd_sbp * 0.010251914291290456
        # interaction parts
        a += age_1 * bmi_1 * 31.004952956033886 + age_1 * bmi_2 * -111.29157184391643 + age_1 * sbp * 0.018858524469865853 + age_2 * bmi_1 * 0.0050380102356322029 + age_2 * bmi_2 * -0.013074483002524319 + age_2 * sbp * -0.00001271874191588457
        return 1 - np.power(0.977268040180206, np.exp(a))
Set the reference risk $\mu_{\mathrm{ref}}$.
mu_ref = qrisk3(age, gender, height, weight, tcl_hdl_ratio, sbp, sd_sbp)
print("Reference CVD risk over 10 years: {:.2f}%".format(mu_ref*100))
Reference CVD risk over 10 years: 0.63%
CVD risk - just like mortality - strongly depends on age. As a rule of thumb: it increases by about 10% with each year of age.
For mortality, this rule is known as the Gompertz-Makeham model/approximation, although the approximation is only reasonable between roughly ages 30 and 80.
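A quick consequence of the 10% rule of thumb (a sketch, taking the annual factor 1.1 as given): risk roughly doubles every 7-8 years, and the corresponding continuous rate is $\log(1.1)\approx 0.095$:

```python
import math

annual_factor = 1.1  # assumed 10% risk increase per year of age

# Doubling time: the number of years until risk has doubled.
doubling_time = math.log(2) / math.log(annual_factor)
print(round(doubling_time, 2))            # ≈ 7.27 years
print(round(math.log(annual_factor), 4))  # ≈ 0.0953, the continuous rate
```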
If we change only the age of the reference person, we obtain a plot showing the exponential increase of CVD risk by age.
mu = [qrisk3(x, gender, height, weight, tcl_hdl_ratio, sbp, sd_sbp) for x in age_range]
sn.set()
plt.figure(figsize=(10,6))
plt.plot(age_range, mu/mu_ref)
plt.xlabel('Age')
plt.ylabel('10 year relative CVD risk')
plt.title('Relative CVD risk with age 35 as reference')
plt.show()
CVD risk - just like mortality - depends on BMI in a J-shaped way.
mu = [qrisk3(age, gender, height, x*(height/100)**2, tcl_hdl_ratio, sbp, sd_sbp) for x in bmi_range]
plt.figure(figsize=(10,6))
plt.plot(bmi_range, mu/mu_ref)
plt.xlabel('BMI')
plt.ylabel('10 year relative CVD risk')
plt.title('Relative CVD risk with BMI 23.1 as reference')
plt.show()
However, changing only BMI while holding everything else fixed might not be what you are actually interested in. You might instead be asking one of the following questions, starting with the one that corresponds to the plot above:
Related literature on health predictions:
Let's prepare to work on questions 2 and 3 by defining
$\begin{pmatrix} \mathrm{SBP} \\ \log(\mathrm{BMI})\end{pmatrix} \sim N\left( \begin{pmatrix} 125 \\ 3.2\end{pmatrix}, \begin{pmatrix} 15^2 & 15 \cdot 0.25\rho\\ 15 \cdot 0.25 \rho & 0.25^2\end{pmatrix} \right)$, with a correlation coefficient $\rho = 0.25$.
So, we assume SBP is normally distributed with mean 125mmHg and standard deviation 15mmHg, and BMI is log-normally distributed with mean $\exp(3.2 + 0.25^2/2)\approx 25.3\mathrm{kg}/\mathrm{m}^2$ and standard deviation $\exp(3.2 + 0.25^2/2)\sqrt{\exp(0.25^2)-1}\approx 6.4\mathrm{kg}/\mathrm{m}^2$.
| Country | Mean BMI females | Mean BMI males |
|---|---|---|
| Samoa | 33.5 | 29.9 |
| USA | 28.8 | 28.8 |
| UK | 27.1 | 27.5 |
| Germany | 25.6 | 27.0 |
| Italy | 25.2 | 26.8 |
| France | 24.6 | 26.1 |
| Switzerland | 23.8 | 26.7 |
| Japan | 21.7 | 23.6 |
Source: https://en.wikipedia.org/wiki/List_of_sovereign_states_by_body_mass_index
Period life expectancy at birth in 2019:
| Country | Life exp. females | Life exp. males |
|---|---|---|
| Samoa | 75.5 | 71.3 |
| USA | 81.5 | 76.5 |
| UK | 83.3 | 79.6 |
| Germany | 83.5 | 78.8 |
| Italy | 85.4 | 81.1 |
| France | 85.6 | 79.8 |
| Switzerland | 85.6 | 81.9 |
| Japan | 87.4 | 81.4 |
Source (apart from Samoa): https://mortality.org/
Time to simulate the first synthetic dataset $D_1$:
persons = 1_000_000
D1 = pd.DataFrame(index = range(persons), columns = ['BMI', 'SBP', 'mu', 'event', 'T'])
np.random.seed(0)
sbp_bmi = np.random.multivariate_normal([125, 3.2], np.matmul(np.matmul(np.diag([15, 0.25]), np.array([[1, 0.25], [0.25, 1]])), np.diag([15, 0.25])), size=persons)
D1.SBP = sbp_bmi[:,0]
D1.BMI = np.exp(sbp_bmi[:,1])
D1.mu = qrisk3(age, gender, height, np.exp(sbp_bmi[:,1])*(height/100)**2, tcl_hdl_ratio, sbp_bmi[:,0], sd_sbp)
D1.event = np.where(np.random.uniform(0, 1, persons) < D1.mu, 1, 0)
D1 = D1[(D1.SBP>=min(sbp_range)) & (D1.SBP<=max(sbp_range)) & (D1.BMI>=min(bmi_range)) & (D1.BMI<=max(bmi_range))].reset_index()
plt.figure(figsize=(10,6))
plt.hist2d(D1.BMI, D1.SBP, bins=(80, 80), cmap=plt.cm.inferno)
plt.plot([bmi_range[0], bmi_range[len(bmi_range)-1]], [125, 125], 'r')
plt.plot(bmi_range, (np.log(bmi_range)-3.2)*0.25*15+125, 'g')
plt.xlabel('BMI')
plt.ylabel('SBP')
plt.title('Density plot')
plt.colorbar()
plt.show()
print("Number of events: ", np.sum(D1.event))
print("Number of individuals: ", len(D1))
D1.head()
Number of events:  5363
Number of individuals:  688503
| index | BMI | SBP | mu | event | T | |
|---|---|---|---|---|---|---|
| 0 | 3 | 22.286183 | 110.748826 | 0.005485 | 0 | NaN |
| 1 | 4 | 27.271473 | 126.547869 | 0.007587 | 0 | NaN |
| 2 | 5 | 34.570903 | 122.837879 | 0.008999 | 0 | NaN |
| 3 | 6 | 24.091885 | 113.584311 | 0.005833 | 0 | NaN |
| 4 | 7 | 25.868334 | 118.341715 | 0.006495 | 0 | NaN |
Coming back to the first 3 questions:
mu_marginal = [qrisk3(age, gender, height, x*(height/100)**2, tcl_hdl_ratio, 125, sd_sbp) for x in bmi_range]
mu_conditional = [qrisk3(age, gender, height, x*(height/100)**2, tcl_hdl_ratio, (np.log(x)-3.2)*0.25*15+125, sd_sbp) for x in bmi_range]
expected_mu = np.zeros(len(bmi_range)-1)
for j in np.arange(0, len(bmi_range) - 1):
    expected_mu[j] = np.mean(D1.mu[(D1.BMI >= bmi_range[j]) & (D1.BMI < bmi_range[j+1])])
plt.figure(figsize=(10,6))
plt.plot(bmi_range, mu_marginal/mu_ref, 'r', bmi_range, mu_conditional/mu_ref, 'g', bmi_range[0:len(bmi_range)-1], expected_mu/mu_ref, 'b')
plt.xlabel('BMI')
plt.ylabel('10 year relative CVD incidence risk')
plt.legend(["1. SBP = 125, independent of BMI", "2. SBP set to conditional expectation wrt BMI", "3. Expected risk conditioned on BMI"])
plt.show()
Question: why is the blue curve so different? Let's add a dimension to see...
mx, my = np.meshgrid(sbp_range, bmi_range)
mu = qrisk3(age, gender, height, my*(height/100)**2, tcl_hdl_ratio, mx, sd_sbp)
fig = go.Figure(data=[go.Surface(z=mu/mu_ref, x=sbp_range, y=bmi_range)])
fig.update_traces(contours_z=dict(show=True, project_z=True), colorscale='Jet')
fig.update_layout(autosize=True, scene = dict(xaxis_title='SBP', yaxis_title='BMI', zaxis_title='Ground truth 10 year relative CVD incidence risk'), width=1000, height=1000)
fig.show()
Answer: due to the convexity along SBP (and also due to trimming SBP into plausible ranges in $D_1$).
Thinking about epidemiology in terms of
can be very useful.
For all models, we will work on min-max scaled covariates and assess 3 performance metrics (ROC AUC, MSE with respect to log ground truth, logistic deviance) on unseen test data taken from $D_1$.
D1["BMI_scaled"] = (D1.BMI-min(bmi_range))/(max(bmi_range)-min(bmi_range))
D1["SBP_scaled"] = (D1.SBP-min(sbp_range))/(max(sbp_range)-min(sbp_range))
D1_train, D1_test = train_test_split(D1, test_size = 0.15, random_state = 0)
We use the statsmodels package to run a logistic regression/GLM (generalized linear model) on the train subset of $D_1$.
The functional form is given by
$$ \mathrm{logit}(\mu_{\text{log reg}}({\bf x})) = \beta_0 + \beta_1\mathrm{SBP\_scaled}+ \beta_2\mathrm{BMI\_scaled} + \beta_3\mathrm{BMI\_scaled}^2~, $$where $\mathrm{logit}(x) := \log(x/(1-x))$ is called the link function of the GLM.
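The logit link and its inverse (the sigmoid, often called expit) form an invertible pair, which is what lets us move between the linear predictor and the probability scale; a minimal sketch:

```python
import math

def logit(p):
    # Link function: log-odds of a probability p in (0, 1).
    return math.log(p / (1 - p))

def expit(a):
    # Inverse link: maps any real log-odds back to a probability.
    return 1 / (1 + math.exp(-a))

# Applying one after the other recovers the original probability.
print(expit(logit(0.3)))  # ≈ 0.3
```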
We define $y_j:=\mathrm{event}_j$, where $y_j = 1$ means that the $j$-th event in the train subset of $D_1$ happens, and $\widehat{y}_j = \mu_{\text{log reg}}({\bf x}_j)$. The (total) deviance function that gets minimized during the fitting of the logistic regression (equivalent to maximizing log-likelihood) is given by
$$ D(y, {\widehat{y}}) = -2\sum_{j=1}^{|D_{1,\mathrm{train}}|} y_j\log(\widehat{y}_j)+(1-y_j)\log(1-\widehat{y}_j)~. $$It is proportional to the binary cross-entropy and the log loss.
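A small self-contained sketch of this deviance (using tiny made-up vectors, not the $D_1$ data):

```python
import numpy as np

def deviance(y, y_hat):
    # Total binomial deviance D(y, y_hat) = -2 * sum[y log(y_hat) + (1-y) log(1-y_hat)].
    y, y_hat = np.asarray(y, dtype=float), np.asarray(y_hat, dtype=float)
    return -2 * np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

y = np.array([1, 0, 0, 1])
y_hat = np.array([0.8, 0.2, 0.1, 0.6])
print(deviance(y, y_hat))  # ≈ 2.12; dividing by len(y) gives 2x the binary cross-entropy
```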
log_reg_D1 = sm.logit(formula='event ~ SBP_scaled + BMI_scaled + I(BMI_scaled**2)', data=D1_train).fit()
print(log_reg_D1.summary())
mx, my = np.meshgrid((sbp_range-min(sbp_range))/(max(sbp_range)-min(sbp_range)), (bmi_range-min(bmi_range))/(max(bmi_range)-min(bmi_range)))
bmi_sbp = pd.DataFrame({'SBP_scaled': mx.flatten(), 'BMI_scaled': my.flatten()})
pred = log_reg_D1.predict(bmi_sbp).values.reshape(len(bmi_range), len(sbp_range))
fig = go.Figure(data=[go.Surface(z=pred/mu_ref, x=sbp_range, y=bmi_range)])
fig.update_traces(contours_z=dict(show=True, project_z=True), colorscale='Jet')
fig.update_layout(autosize=True, scene = dict(xaxis_title='SBP', yaxis_title='BMI', zaxis_title='Fitted 10 year relative CVD incidence risk'), width=1000, height=1000)
fig.show()
fig = go.Figure(data=[go.Surface(z=mu/mu_ref, x=sbp_range, y=bmi_range)])
fig.update_traces(contours_z=dict(show=True, project_z=True), colorscale='Jet')
fig.update_layout(autosize=True, scene = dict(xaxis_title='SBP', yaxis_title='BMI', zaxis_title='Ground truth 10 year relative CVD incidence risk'), width=1000, height=1000)
fig.show()
pred = log_reg_D1.predict(D1_test)
auc_log_reg_D1 = roc_auc_score(D1_test.event, pred)
mse_log_reg_D1 = np.mean(np.square(np.log(D1_test.mu) - np.log(pred)))
dev_log_reg_D1 = -2*np.sum(D1_test.event*logit(pred) + np.log(1-pred))
kld_log_reg_D1 = np.sum(D1_test.event*np.log(D1_test.event/pred)) # terms with event == 0 evaluate to NaN (hence the divide-by-zero warning below) and are skipped by the pandas sum
print("ROC AUC (test): ", auc_log_reg_D1)
print("MSE wrt log ground truth: ", mse_log_reg_D1)
print("Logistic deviance: ", dev_log_reg_D1)
print("Kullback-Leibler divergence: ", kld_log_reg_D1)
prob_true, prob_pred = calibration_curve(D1_test.event, pred, n_bins=20, strategy='quantile')
plt.figure()
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0.004, 0.014], [0.004, 0.014], 'k')
plt.axis([0.004, 0.014, 0.004, 0.014])
plt.title('Calibration plot logistic regression')
plt.xlabel('Predicted probability')
plt.ylabel('True probability in each bin')
plt.show()
Optimization terminated successfully.
Current function value: 0.045549
Iterations 9
Logit Regression Results
==============================================================================
Dep. Variable: event No. Observations: 585227
Model: Logit Df Residuals: 585223
Method: MLE Df Model: 3
Date: Wed, 20 Sep 2023 Pseudo R-squ.: 0.003406
Time: 17:31:05 Log-Likelihood: -26656.
converged: True LL-Null: -26747.
Covariance Type: nonrobust LLR p-value: 2.941e-39
======================================================================================
coef std err z P>|z| [0.025 0.975]
--------------------------------------------------------------------------------------
Intercept -5.1893 0.051 -102.681 0.000 -5.288 -5.090
SBP_scaled 0.4928 0.058 8.425 0.000 0.378 0.607
BMI_scaled -0.0933 0.219 -0.427 0.670 -0.522 0.335
I(BMI_scaled ** 2) 0.7182 0.230 3.126 0.002 0.268 1.168
======================================================================================
ROC AUC (test):  0.5617061723108522
MSE wrt log ground truth:  0.0015757922217920983
Logistic deviance:  9223.880067951093
Kullback-Leibler divergence:  3807.9559497679784
/usr/local/lib/python3.10/dist-packages/pandas/core/arraylike.py:402: RuntimeWarning: divide by zero encountered in log
Conclusions
Interpretation of model coefficients $\beta_j$
First, define the odds of an event happening given $\bf x$ by
$$ \mathrm{odds}(y = 1 | {\bf x}) = \frac{P(y=1| {\bf x})}{P(y=0| {\bf x})} = \frac{P(y=1| {\bf x})}{1-P(y=1| {\bf x})}~. $$The log odds under the logistic regression model $\mu_{\text{log reg}}({\bf x})$ then are given by
$$ \log(\mathrm{odds}(y = 1 | {\bf x})) = \mathrm{logit}(\mathrm{P}(y = 1 | {\bf x})) = \mathrm{logit}(\mu_{\text{log reg}}({\bf x})) = \beta_0 + \beta_1\mathrm{SBP\_scaled}+ \beta_2\mathrm{BMI\_scaled} + \beta_3\mathrm{BMI\_scaled}^2~. $$Increasing an $x_j$ by 1 compared to $\bf x$ leads to the following odds ratio (OR)
$$ \frac{\mathrm{odds}(y=1|\text{increase $x_j$ by 1 compared to $\bf x$})}{\mathrm{odds}(y=1| {\bf x})} = \exp(\beta_j)~. $$
We use the lifelines package to run a Cox regression/Cox proportional hazards model on the train subset of $D_1$.
Cox regression is a semi-parametric model with one part being a baseline hazard rate $h_0(t)$ depending only on time $t$, but not on $\bf x$, and a time-independent part, but depending on $\bf x$. The full functional form of the hazard rate is given by
$$ h(t | {\bf x}) = h_0(t)\exp\left(\beta_1\mathrm{SBP\_scaled}+ \beta_2\mathrm{BMI\_scaled} + \beta_3\mathrm{BMI\_scaled}^2\right)~. $$
What is a hazard rate?
We defined the probability of a CVD event happening over a time horizon of 10 years by $\mu({\bf x})$. A hazard rate is an instantaneous, time-continuous version of this and is comparable to a continuous compound interest rate to evaluate zero-coupon bonds. In general, hazard rates are defined by
$$ h(t) := \lim_{\tau\to 0}\frac{P(t\leq T < t+\tau)}{\tau P(T\geq t)}~. $$
Interpretation of model coefficients $\beta_j$
For logistic regression we could interpret the model coefficients via odds ratios (ORs). For Cox regression, the model coefficients can be interpreted by hazard ratios (HRs). Increasing an $x_j$ by 1 compared to $\bf x$ leads to a hazard ratio of
$$ \frac{h(t|\text{increase $x_j$ by 1 compared to $\bf x$})}{h(t|{\bf x})} = \exp(\beta_j)~. $$Note that the hazard ratio is constant over time $t$, which is the reason why Cox regression is called a proportional hazards model.
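The odds-ratio and hazard-ratio interpretations both follow from the log-linear structure of the two models. A quick numerical check for the odds-ratio case, with hypothetical coefficients $\beta_0=-5$, $\beta_1=0.5$:

```python
import math

def expit(a):
    return 1 / (1 + math.exp(-a))

b0, b1 = -5.0, 0.5  # hypothetical coefficients, for illustration only

p0 = expit(b0)       # probability with x_1 = 0
p1 = expit(b0 + b1)  # probability with x_1 = 1

# The odds ratio equals exp(b1), independent of b0.
odds_ratio = (p1 / (1 - p1)) / (p0 / (1 - p0))
print(round(odds_ratio, 6), round(math.exp(b1), 6))  # the two agree
```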
The survival probability (i.e., the event is not happening until time $t$) is given by
$$ S(t|{\bf x}) = \exp\left(-\int_{\tau = 0}^t h(\tau|{\bf x})\ d\tau\right)~, $$and we obtain the probability of an event happening by
$$ \mu_{\text{Cox}}({\bf x}) = 1 - S(10|{\bf x})~, $$which we can compare against logistic regression $\mu_{\text{log reg}}({\bf x})$.
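For the simplest special case of a constant hazard $h(t|{\bf x}) = \lambda$, these formulas can be checked directly (a sketch with a hypothetical hazard of 0.1% per year):

```python
import math

lam = 0.001  # hypothetical constant hazard rate per year

# Survival curve S(t) = exp(-lam * t), so mu = 1 - S(10).
surv = lambda t: math.exp(-lam * t)
mu_10y = 1 - surv(10)
print(round(mu_10y, 5))  # ≈ 0.00995

# Sanity check of the hazard-rate definition: (S(t) - S(t + tau)) / (tau * S(t)) -> lam.
t, tau = 3.0, 1e-6
hazard_approx = (surv(t) - surv(t + tau)) / (tau * surv(t))
print(round(hazard_approx, 4))  # ≈ 0.001
```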
While logistic regression only considered the binary event variable as target, Cox regression also makes use of the time to event, which provides additional information. Cox regression is essentially fitting the survival curves $S(t|{\bf x})$ that contain more information than the binary event variable. We thus need to add the time to event to our dataset $D_1$.
We define a function draw_T that draws a time of the event, conditioned on the event happening over the 10 year time horizon. We use the Gompertz-Makeham assumption that risk increases by 10% annually, which corresponds to an instantaneous, time-continuous rate of around 0.095, i.e., $\exp(0.095)\approx 1.1$. Overall, the CDF $F(t)$, $0\leq t\leq 10$, from which draw_T draws is given by
$$ F(t) = \frac{1}{\mu({\bf x})}\left(1-\exp\left(-\int_{\tau = 0}^{t} \alpha\exp(0.095\tau)\ d\tau \right)\right)~, $$where $\alpha$ is chosen such that
$$ 1-\exp\left(-\int_{\tau = 0}^{10} \alpha\exp(0.095\tau)\ d\tau \right) = \mu({\bf x})~. $$
def draw_T(mu):
    beta = 0.095 # assuming increase of risk of around 10% per year
    # bisection approach to obtain alpha
    alpha_lb = 0
    alpha_ub = 1
    while (alpha_ub - alpha_lb > 0.000005):
        if np.exp((alpha_ub + alpha_lb)/2*(1-np.exp(max_T*beta))/beta) > 1 - mu:
            alpha_lb = (alpha_ub + alpha_lb)/2
        else:
            alpha_ub = (alpha_ub + alpha_lb)/2
    alpha = (alpha_ub + alpha_lb)/2
    t_lb = 0
    t_ub = max_T
    r = np.random.uniform(0, 1)
    # inversion of 1 - survival curve by bisection
    while (t_ub - t_lb > 0.01):
        if (1 - np.exp(alpha*(1 - np.exp((t_ub + t_lb)/2*beta))/beta))/mu < r:
            t_lb = (t_ub + t_lb)/2
        else:
            t_ub = (t_ub + t_lb)/2
    return (t_ub + t_lb)/2
plt.hist([draw_T(0.01) for x in range(50000)] , bins=40)
plt.title("Histogram of 50k draws with low event probability 0.01")
plt.xlabel("Time/year")
plt.show()
plt.hist([draw_T(0.8) for x in range(50000)] , bins=40)
plt.title("Histogram of 50k draws with high event probability 0.8")
plt.xlabel("Time/year")
plt.show()
Apply draw_T to $D_{1, \mathrm{train}}$ and $D_{1, \mathrm{test}}$:
np.random.seed(0)
D1_train["T"] = D1_train.apply(lambda row : max_T if row.event == 0 else draw_T(row.mu), axis=1)
D1_test["T"] = D1_test.apply(lambda row : max_T if row.event == 0 else draw_T(row.mu), axis=1)
print("Number of person years: ", int(np.sum(D1_train["T"]) + np.sum(D1_test["T"])))
Number of person years: 6862185
Ready for Cox regression... but let's first look at the estimated overall survival curve, using the Kaplan-Meier estimator:
kmf = ll.KaplanMeierFitter()
kmf.fit(D1_train['T'], event_observed=D1_train['event'])
kmf.plot_survival_function();
Finally, Cox regression:
cph_D1 = ll.CoxPHFitter()
cph_D1.fit(D1_train, 'T', event_col='event', formula="SBP_scaled + BMI_scaled + I(BMI_scaled**2)")
cph_D1.print_summary()
cph_D1.plot();
survival = cph_D1.predict_survival_function(D1_test)
pred = 1 - np.squeeze(np.array(survival.iloc[-1,:]))
auc_cox_reg_D1 = roc_auc_score(D1_test.event, pred)
mse_cox_reg_D1 = np.mean(np.square(np.log(D1_test.mu) - np.log(pred)))
dev_cox_reg_D1 = -2*np.sum(D1_test.event*logit(pred) + np.log(1-pred))
kld_cox_reg_D1 = np.sum(D1_test.event*np.log(D1_test.event/pred)) # terms with event == 0 evaluate to NaN and are skipped by the pandas sum
print("ROC AUC (test): ", auc_cox_reg_D1)
print("MSE wrt log ground truth: ", mse_cox_reg_D1)
print("Logistic deviance: ", dev_cox_reg_D1)
print("Kullback-Leibler divergence: ", kld_cox_reg_D1)
prob_true, prob_pred = calibration_curve(D1_test.event, pred, n_bins=20, strategy='quantile')
plt.figure()
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0.004, 0.014], [0.004, 0.014], 'k')
plt.axis([0.004, 0.014, 0.004, 0.014])
plt.title('Calibration plot Cox regression')
plt.xlabel('Predicted probability')
plt.ylabel('True probability in each bin')
plt.show()
mx, my = np.meshgrid((sbp_range-min(sbp_range))/(max(sbp_range)-min(sbp_range)), (bmi_range-min(bmi_range))/(max(bmi_range)-min(bmi_range)))
bmi_sbp = pd.DataFrame({'SBP_scaled': mx.flatten(), 'BMI_scaled': my.flatten()})
pred = (1 - np.array(cph_D1.predict_survival_function(bmi_sbp))[-1,:]).reshape(len(bmi_range), len(sbp_range))
fig = go.Figure(data=[go.Surface(z=pred/mu_ref, x=sbp_range, y=bmi_range)])
fig.update_traces(contours_z=dict(show=True, project_z=True), colorscale='Jet')
fig.update_layout(autosize=True, scene = dict(xaxis_title='SBP', yaxis_title='BMI', zaxis_title='Fitted 10 year relative CVD incidence risk'), width=1000, height=1000)
fig.show()
| model | lifelines.CoxPHFitter |
|---|---|
| duration col | 'T' |
| event col | 'event' |
| baseline estimation | breslow |
| number of observations | 585227 |
| number of events observed | 4574 |
| partial log-likelihood | -60632.64 |
| time fit was run | 2023-09-20 17:31:52 UTC |
| coef | exp(coef) | se(coef) | coef lower 95% | coef upper 95% | exp(coef) lower 95% | exp(coef) upper 95% | cmp to | z | p | -log2(p) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| SBP_scaled | 0.49 | 1.63 | 0.06 | 0.38 | 0.60 | 1.46 | 1.83 | 0.00 | 8.42 | <0.005 | 54.56 |
| BMI_scaled | -0.09 | 0.91 | 0.22 | -0.52 | 0.34 | 0.60 | 1.40 | 0.00 | -0.42 | 0.68 | 0.57 |
| I(BMI_scaled**2) | 0.71 | 2.04 | 0.23 | 0.26 | 1.16 | 1.30 | 3.19 | 0.00 | 3.12 | <0.005 | 9.11 |
| Concordance | 0.56 |
|---|---|
| Partial AIC | 121271.27 |
| log-likelihood ratio test | 182.09 on 3 df |
| -log2(p) of ll-ratio test | 127.92 |
ROC AUC (test):  0.561706568045608
MSE wrt log ground truth:  0.0015788497360106327
Logistic deviance:  9223.87818207559
Kullback-Leibler divergence:  3807.969519644536
Survival curves stratified by (scaled) SBP:
cph_D1.plot_partial_effects_on_outcome(covariates='SBP_scaled', values=np.round((np.array([100, 110, 120, 130, 140, 150])-np.min(sbp_range))/(np.max(sbp_range)-np.min(sbp_range)), 2));
Model performances on $D_{1, \mathrm{test}}$ overview. No substantial improvement of Cox regression over logistic regression so far.
print("Logistic regression")
print("===================")
print("ROC AUC (test): ", auc_log_reg_D1)
print("MSE wrt log ground truth: ", mse_log_reg_D1)
print("Logistic deviance: ", dev_log_reg_D1)
print("Kullback-Leibler divergence: ", kld_log_reg_D1)
print("\nCox regression")
print("==============")
print("ROC AUC (test): ", auc_cox_reg_D1)
print("MSE wrt log ground truth: ", mse_cox_reg_D1)
print("Logistic deviance: ", dev_cox_reg_D1)
print("Kullback-Leibler divergence: ", kld_cox_reg_D1)
Logistic regression
===================
ROC AUC (test):  0.5617061723108522
MSE wrt log ground truth:  0.0015757922217920983
Logistic deviance:  9223.880067951093
Kullback-Leibler divergence:  3807.9559497679784

Cox regression
==============
ROC AUC (test):  0.561706568045608
MSE wrt log ground truth:  0.0015788497360106327
Logistic deviance:  9223.87818207559
Kullback-Leibler divergence:  3807.969519644536
First, we show the analogy between neural nets and GLMs by defining a shallow neural net without any hidden layers. We have to add the quadratic term I(BMI_scaled**2) from the logistic regression formula as an explicit column to the data.
D1_train['BMI_scaled_squared'] = D1_train.BMI_scaled**2
D1_test['BMI_scaled_squared'] = D1_test.BMI_scaled**2
We use the keras package integrated into tensorflow.
shallow_neural_net_D1 = Sequential([
    Input(shape=(3,)),
    Dense(1, activation="sigmoid", kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0))
])
The functional form of this neural net is given by
$$ \mathrm{logit}(\mu_{\text{shallow neural net}}({\bf x})) = \beta_0 + \beta_1\mathrm{SBP\_scaled}+ \beta_2\mathrm{BMI\_scaled} + \beta_3\mathrm{BMI\_scaled}^2~, $$and the loss function that gets minimized - the binary cross-entropy - is, up to a constant factor,
$$ D(y, {\widehat{y}}) = -2\sum_{j=1}^{|D_{1,\mathrm{train}}|} y_j\log(\widehat{y}_j)+(1-y_j)\log(1-\widehat{y}_j)~. $$This looks familiar - it is exactly the same objective as for logistic regression. However, while logistic regression essentially finds the exact solution (or one of the exact solutions) by iteratively reweighted least squares, the neural net is trained in a much more general way, with much slower convergence - if at all - to the global minimum. We will see later, with the more complex dataset $D_2$, that one usually allows for many trainable parameters and then applies early stopping when the model starts to overfit the training data.
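The forward pass of such a single-neuron net can be written out in a few lines of numpy (with hypothetical weights chosen for illustration, not the fitted ones):

```python
import numpy as np

def sigmoid(a):
    return 1 / (1 + np.exp(-a))

# Hypothetical weights, in the order SBP_scaled, BMI_scaled, BMI_scaled^2.
w = np.array([0.49, -0.09, 0.72])
b0 = -5.19  # intercept / bias

# One scaled observation; the neuron computes sigmoid(b0 + w · x),
# i.e. exactly the functional form of the logistic regression above.
x = np.array([0.4, 0.3, 0.09])
mu = sigmoid(b0 + w @ x)
print(mu)
```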
Fitting/training the shallow neural net:
X_train = D1_train[['SBP_scaled', 'BMI_scaled', 'BMI_scaled_squared']]
y_train = D1_train['event']
shallow_neural_net_D1.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.002), loss = 'binary_crossentropy')
history = shallow_neural_net_D1.fit(X_train, y_train, batch_size = 256, epochs = 10, shuffle = True, validation_split = 0.2, verbose = 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()
Epoch 1/10
1829/1829 [==============================] - 5s 2ms/step - loss: 0.1908 - val_loss: 0.0689
Epoch 2/10
1829/1829 [==============================] - 4s 2ms/step - loss: 0.0588 - val_loss: 0.0506
Epoch 3/10
1829/1829 [==============================] - 4s 2ms/step - loss: 0.0515 - val_loss: 0.0478
Epoch 4/10
1829/1829 [==============================] - 3s 2ms/step - loss: 0.0495 - val_loss: 0.0460
Epoch 5/10
1829/1829 [==============================] - 3s 2ms/step - loss: 0.0478 - val_loss: 0.0446
Epoch 6/10
1829/1829 [==============================] - 3s 2ms/step - loss: 0.0468 - val_loss: 0.0440
Epoch 7/10
1829/1829 [==============================] - 4s 2ms/step - loss: 0.0464 - val_loss: 0.0437
Epoch 8/10
1829/1829 [==============================] - 4s 2ms/step - loss: 0.0462 - val_loss: 0.0436
Epoch 9/10
1829/1829 [==============================] - 3s 2ms/step - loss: 0.0461 - val_loss: 0.0436
Epoch 10/10
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0461 - val_loss: 0.0435
Inspecting the coefficients/weights of the shallow neural net and comparing them to the coefficients of logistic regression:
print("Shallow neural net intercept: ", shallow_neural_net_D1.layers[0].get_weights()[1])
print("Shallow neural net model weights: ", shallow_neural_net_D1.layers[0].get_weights()[0])
print("Logistic regression coefficients: ")
print(log_reg_D1.params)
Shallow neural net intercept:  [-4.9802976]
Shallow neural net model weights:  [[ 0.22214867]
 [-0.7075379 ]
 [ 1.3153745 ]]
Logistic regression coefficients: 
Intercept            -5.189281
SBP_scaled            0.492821
BMI_scaled           -0.093293
I(BMI_scaled ** 2)    0.718205
dtype: float64
Overview of model performance on $D_{1, \mathrm{test}}$:
X_test = D1_test[['SBP_scaled', 'BMI_scaled', 'BMI_scaled_squared']]
y_test = D1_test['event']
pred = shallow_neural_net_D1.predict(X_test).flatten()
auc_shallow_neural_net_D1 = roc_auc_score(D1_test.event, pred)
mse_shallow_neural_net_D1 = np.mean(np.square(np.log(D1_test.mu) - np.log(pred)))
dev_shallow_neural_net_D1 = -2*np.sum(D1_test.event*logit(pred) + np.log(1-pred))
kld_shallow_neural_net_D1 = np.sum(D1_test.event*np.log(D1_test.event/pred))
print("Logistic regression")
print("===================")
print("ROC AUC (test): ", auc_log_reg_D1)
print("MSE wrt log ground truth: ", mse_log_reg_D1)
print("Logistic deviance: ", dev_log_reg_D1)
print("Kullback-Leibler divergence: ", kld_log_reg_D1)
print("\nCox regression")
print("==============")
print("ROC AUC (test): ", auc_cox_reg_D1)
print("MSE wrt log ground truth: ", mse_cox_reg_D1)
print("Logistic deviance: ", dev_cox_reg_D1)
print("Kullback-Leibler divergence: ", kld_cox_reg_D1)
print("\nShallow neural net")
print("==================")
print("ROC AUC (test): ", auc_shallow_neural_net_D1)
print("MSE wrt log ground truth: ", mse_shallow_neural_net_D1)
print("Logistic deviance: ", dev_shallow_neural_net_D1)
print("Kullback-Leibler divergence: ", kld_shallow_neural_net_D1)
3228/3228 [==============================] - 4s 1ms/step

Logistic regression
===================
ROC AUC (test):  0.5617061723108522
MSE wrt log ground truth:  0.0015757922217920983
Logistic deviance:  9223.880067951093
Kullback-Leibler divergence:  3807.9559497679784

Cox regression
==============
ROC AUC (test):  0.561706568045608
MSE wrt log ground truth:  0.0015788497360106327
Logistic deviance:  9223.87818207559
Kullback-Leibler divergence:  3807.969519644536

Shallow neural net
==================
ROC AUC (test):  0.5582048039404497
MSE wrt log ground truth:  0.01092150969267656
Logistic deviance:  9232.214921846986
Kullback-Leibler divergence:  3839.3764256154614
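The same four metrics are computed with identical formulas for every model in this notebook; purely as a refactoring sketch (not part of the original code), they could be collected in one helper. Note a subtlety of the KLD line: for non-events the term is $0\cdot\log(0)=$ NaN, and pandas' sum skips NaNs, so effectively only event rows contribute.

```python
import numpy as np
from scipy.special import logit
from sklearn.metrics import roc_auc_score

def evaluation_metrics(event, pred, mu):
    """The four comparison metrics used throughout this notebook.

    `event` should be a pandas Series of 0/1 outcomes: for non-events the
    KLD term is 0 * log(0) = NaN, and pandas' sum skips NaNs, so only
    event rows contribute to the Kullback-Leibler divergence.
    """
    return {
        'roc_auc': roc_auc_score(event, pred),
        'mse_log_mu': np.mean(np.square(np.log(mu) - np.log(pred))),
        'deviance': -2 * np.sum(event * logit(pred) + np.log(1 - pred)),
        'kld': np.sum(event * np.log(event / pred)),
    }
```

With this helper, e.g. `evaluation_metrics(D1_test.event, pred, D1_test.mu)` would reproduce the numbers printed above.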
The shallow neural net does not outperform the traditional models in any of the performance metrics.
Let's create a more complex, deeper neural net with 2 hidden layers:
deep_neural_net_D1 = Sequential([
Input(shape=(2,)),
Dense(256, kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0)),
Activation(tf.keras.activations.relu),
Dense(16, kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0)),
Activation(tf.keras.activations.relu),
Dense(1, activation = "sigmoid", kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0))
])
print(deep_neural_net_D1.summary())
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 256) 768
activation (Activation) (None, 256) 0
dense_2 (Dense) (None, 16) 4112
activation_1 (Activation) (None, 16) 0
dense_3 (Dense) (None, 1) 17
=================================================================
Total params: 4897 (19.13 KB)
Trainable params: 4897 (19.13 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None
What does the functional form of this neural net look like?
Input layer: (scaled) SBP and (scaled) BMI.
Each of the 256 neurons in the next layer, the first hidden layer, takes SBP and BMI as input and calculates $z_j^{(1)}$
$$ z_j^{(1)} = \phi^{(1)}\left( \beta_{0,j}^{(1)} + \beta_{1,j}^{(1)}\mathrm{SBP\_scaled}+ \beta_{2,j}^{(1)}\mathrm{BMI\_scaled}\right)~\text{, for $1\leq j\leq 256$,} $$where $\phi^{(1)}(x) = \mathrm{ReLU}(x) = x\cdot 1_{x\geq 0}$.
Each of the 16 neurons in the next hidden layer takes the 256 outputs $z_j^{(1)}$ of the previous layer as input and calculates $z_j^{(2)}$
$$ z_j^{(2)} = \phi^{(2)}\left( \beta_{0,j}^{(2)} + \beta_{1,j}^{(2)}z_1^{(1)}+ \cdots +\beta_{256,j}^{(2)}z_{256}^{(1)}\right)~\text{, for $1\leq j\leq 16$,} $$where also $\phi^{(2)} = \mathrm{ReLU}$.
Finally, the last layer calculates the output
$$ \mu_{\text{deep neural net}}({\bf x}) = \phi^{(3)}\left( \beta_{0,1}^{(3)} + \beta_{1,1}^{(3)}z_1^{(2)}+ \cdots +\beta_{16,1}^{(3)}z_{16}^{(2)}\right)~, $$where $\phi^{(3)}(x) = \mathrm{logit}^{-1}(x) = 1/(1+\exp(-x))$.
The number of parameters/coefficients/weights is $256+2\cdot 256=768$ in the first hidden layer, $16+256\cdot 16 = 4112$ in the second hidden layer and $17$ parameters in the last layer, 4897 parameters in total.
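These counts follow from each dense layer having one bias plus one weight per input for every output neuron, matching the `model.summary()` output above; a quick arithmetic check:

```python
def dense_params(n_in, n_out):
    # one bias plus n_in weights for each of the n_out output neurons
    return n_out * (n_in + 1)

assert dense_params(2, 256) == 768     # first hidden layer
assert dense_params(256, 16) == 4112   # second hidden layer
assert dense_params(16, 1) == 17       # output layer
assert dense_params(2, 256) + dense_params(256, 16) + dense_params(16, 1) == 4897
```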
X_train = D1_train[['SBP_scaled', 'BMI_scaled']]
y_train = D1_train['event']
tf.random.set_seed(0)
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(0.005, decay_steps=1000, decay_rate=0.99)
opt = tf.keras.optimizers.Adam(learning_rate = lr_schedule)
deep_neural_net_D1.compile(optimizer = opt, loss = 'binary_crossentropy')
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience = 10, restore_best_weights=True)
history = deep_neural_net_D1.fit(X_train, y_train, batch_size = 256, epochs = 20, shuffle = True, validation_split = 0.20, verbose = 1, callbacks=[callback])
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()
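For reference, `tf.keras.optimizers.schedules.ExponentialDecay` (with the default `staircase=False`) yields a learning rate of `initial_lr * decay_rate ** (step / decay_steps)`; the schedule used above can be previewed without TensorFlow:

```python
import numpy as np

def exponential_decay(step, initial_lr=0.005, decay_steps=1000, decay_rate=0.99):
    # Mirrors tf.keras.optimizers.schedules.ExponentialDecay with staircase=False:
    # lr(step) = initial_lr * decay_rate ** (step / decay_steps)
    return initial_lr * decay_rate ** (step / decay_steps)

# With 1829 batches per epoch, 20 epochs correspond to ~36,580 optimizer steps
assert np.isclose(exponential_decay(0), 0.005)
assert 0.003 < exponential_decay(20 * 1829) < 0.004   # decays gently over training
```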
Epoch 1/20
1829/1829 [==============================] - 7s 3ms/step - loss: 0.0494 - val_loss: 0.0438
Epoch 2/20
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0465 - val_loss: 0.0436
Epoch 3/20
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0465 - val_loss: 0.0438
Epoch 4/20
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0464 - val_loss: 0.0435
Epoch 5/20
1829/1829 [==============================] - 4s 2ms/step - loss: 0.0463 - val_loss: 0.0441
Epoch 6/20
1829/1829 [==============================] - 6s 3ms/step - loss: 0.0463 - val_loss: 0.0435
Epoch 7/20
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0463 - val_loss: 0.0436
Epoch 8/20
1829/1829 [==============================] - 5s 2ms/step - loss: 0.0463 - val_loss: 0.0436
Epoch 9/20
1829/1829 [==============================] - 6s 3ms/step - loss: 0.0462 - val_loss: 0.0435
Epoch 10/20
1829/1829 [==============================] - 4s 2ms/step - loss: 0.0462 - val_loss: 0.0435
Epoch 11/20
1829/1829 [==============================] - 6s 3ms/step - loss: 0.0462 - val_loss: 0.0435
Epoch 12/20
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0462 - val_loss: 0.0435
Epoch 13/20
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0462 - val_loss: 0.0436
Epoch 14/20
1829/1829 [==============================] - 6s 3ms/step - loss: 0.0462 - val_loss: 0.0438
Epoch 15/20
1829/1829 [==============================] - 5s 2ms/step - loss: 0.0462 - val_loss: 0.0436
Epoch 16/20
1829/1829 [==============================] - 6s 3ms/step - loss: 0.0462 - val_loss: 0.0435
Epoch 17/20
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0462 - val_loss: 0.0435
Epoch 18/20
1829/1829 [==============================] - 5s 3ms/step - loss: 0.0462 - val_loss: 0.0435
Epoch 19/20
1829/1829 [==============================] - 6s 3ms/step - loss: 0.0462 - val_loss: 0.0435
Epoch 20/20
1829/1829 [==============================] - 4s 2ms/step - loss: 0.0462 - val_loss: 0.0435
X_test = D1_test[['SBP_scaled', 'BMI_scaled']]
y_test = D1_test['event']
pred = deep_neural_net_D1.predict(X_test).flatten()
auc_deep_neural_net_D1 = roc_auc_score(D1_test.event, pred)
mse_deep_neural_net_D1 = np.mean(np.square(np.log(D1_test.mu) - np.log(pred)))
dev_deep_neural_net_D1 = -2*np.sum(D1_test.event*logit(pred) + np.log(1-pred))
kld_deep_neural_net_D1 = np.sum(D1_test.event*np.log(D1_test.event/pred))
print("Logistic regression")
print("===================")
print("ROC AUC (test): ", auc_log_reg_D1)
print("MSE wrt log ground truth: ", mse_log_reg_D1)
print("Logistic deviance: ", dev_log_reg_D1)
print("Kullback-Leibler divergence: ", kld_log_reg_D1)
print("\nCox regression")
print("==============")
print("ROC AUC (test): ", auc_cox_reg_D1)
print("MSE wrt log ground truth: ", mse_cox_reg_D1)
print("Logistic deviance: ", dev_cox_reg_D1)
print("Kullback-Leibler divergence: ", kld_cox_reg_D1)
print("\nShallow neural net")
print("==================")
print("ROC AUC (test): ", auc_shallow_neural_net_D1)
print("MSE wrt log ground truth: ", mse_shallow_neural_net_D1)
print("Logistic deviance: ", dev_shallow_neural_net_D1)
print("Kullback-Leibler divergence: ", kld_shallow_neural_net_D1)
print("\nDeep neural net")
print("==================")
print("ROC AUC (test): ", auc_deep_neural_net_D1)
print("MSE wrt log ground truth: ", mse_deep_neural_net_D1)
print("Logistic deviance: ", dev_deep_neural_net_D1)
print("Kullback-Leibler divergence: ", kld_deep_neural_net_D1)
prob_true, prob_pred = calibration_curve(D1_test.event, pred, n_bins=20, strategy='quantile')
plt.figure()
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0.004, 0.014], [0.004, 0.014], 'k')
plt.axis([0.004, 0.014, 0.004, 0.014])
plt.title('Calibration plot deep neural net')
plt.xlabel('Predicted probability')
plt.ylabel('True probability in each bin')
plt.show()
mx, my = np.meshgrid((sbp_range-min(sbp_range))/(max(sbp_range)-min(sbp_range)), (bmi_range-min(bmi_range))/(max(bmi_range)-min(bmi_range)))
bmi_sbp = pd.DataFrame({'SBP_scaled': mx.flatten(), 'BMI_scaled': my.flatten()})
pred = deep_neural_net_D1.predict(bmi_sbp).reshape(len(bmi_range), len(sbp_range))
fig = go.Figure(data=[go.Surface(z=pred/mu_ref, x=sbp_range, y=bmi_range)])
fig.update_traces(contours_z=dict(show=True, project_z=True), colorscale='Jet')
fig.update_layout(autosize=True, scene = dict(xaxis_title='SBP', yaxis_title='BMI', zaxis_title='Fitted 10 year relative CVD incidence risk'), width=1000, height=1000)
fig.show()
3228/3228 [==============================] - 5s 1ms/step

Logistic regression
===================
ROC AUC (test):  0.5617061723108522
MSE wrt log ground truth:  0.0015757922217920983
Logistic deviance:  9223.880067951093
Kullback-Leibler divergence:  3807.9559497679784

Cox regression
==============
ROC AUC (test):  0.561706568045608
MSE wrt log ground truth:  0.0015788497360106327
Logistic deviance:  9223.87818207559
Kullback-Leibler divergence:  3807.969519644536

Shallow neural net
==================
ROC AUC (test):  0.5582048039404497
MSE wrt log ground truth:  0.01092150969267656
Logistic deviance:  9232.214921846986
Kullback-Leibler divergence:  3839.3764256154614

Deep neural net
==================
ROC AUC (test):  0.5604296185551024
MSE wrt log ground truth:  0.005701733303625341
Logistic deviance:  9227.724153258838
Kullback-Leibler divergence:  3793.78149458303
55/55 [==============================] - 0s 2ms/step
So many parameters, such an advanced model, so poor performance... Why???
It can be useful to investigate and understand how much data is actually available:
plt.figure(figsize=(20,14))
plt.plot(D1_train.loc[D1_train['event'] == 1, 'BMI'], D1_train.loc[D1_train['event'] == 1, 'SBP'], '.')
plt.title(str(np.sum(D1_train.event)) + ' events')
plt.xlabel('BMI')
plt.ylabel('SBP')
plt.show()
A simple "model" that calculates empirical probabilities/risks at a given granularity can also be helpful:
n = 10
counts, xbins, ybins = np.histogram2d(D1_train.BMI, D1_train.SBP, bins=(n, n))
sums, _, _ = np.histogram2d(D1_train.BMI, D1_train.SBP, weights=D1_train.event, bins=(xbins, ybins))
fig = go.Figure(data=[go.Surface(z=sums/counts/mu_ref, x=np.arange(0, 1, 1/n), y=np.arange(0, 1, 1/n))])
fig.update_traces(contours_z=dict(show=True, project_z=True), colorscale='Jet')
fig.update_layout(scene_aspectratio=dict(x=1, y=1, z=1), scene = dict(xaxis_title='Scaled SBP', yaxis_title='Scaled BMI', zaxis_title='Empirical 10 year relative CVD incidence risk'), width=1000, height=1000)
fig.show()
We conclude that dataset $D_1$ is not sufficiently rich to make neural networks (or any other ML model) outperform traditional models.
We thus construct a more complex dataset $D_2$ with additional risk factors.
For $D_2$, we move from CVD as one of the main causes of death to total mortality as target variable and consider a wider age range from 0 to 84.
Source: https://wisqars.cdc.gov/data/lcd/home
See also https://www.causesofdeath.org/
Since QRISK3 is only defined for ages from 25 to 84, and since log QRISK3 would tend to $-\infty$ as age tends to 0, we made the following adjustment for ages below 25:
age = np.maximum(age, 25 - 10*np.tanh((25 - age)/10)) ,
which bends the age curve as follows:
plt.plot(np.arange(0,41), np.maximum(np.arange(0,41), 25 - 10*np.tanh((25 - np.arange(0,41))/10)))
plt.xlabel('age')
plt.ylabel('adjusted age')
plt.show()
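A few spot checks of this adjustment (a small sketch): it leaves ages at or above 25 unchanged, is strictly increasing, and maps age 0 to roughly 15.

```python
import numpy as np

def bend_age(age):
    # adjustment used above: identity for ages >= 25, smooth tanh floor below
    return np.maximum(age, 25 - 10 * np.tanh((25 - age) / 10))

assert bend_age(40.0) == 40.0                             # unchanged above 25
assert bend_age(25.0) == 25.0                             # fixed point at 25
assert np.isclose(bend_age(0.0), 25 - 10 * np.tanh(2.5))  # ~15.13 at age 0
ages = np.linspace(0, 84, 200)
assert np.all(np.diff(bend_age(ages)) > 0)                # strictly increasing
```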
We take annual mortality rates $q_x$ of Swiss males in year 2019 from the Human Mortality Database, https://mortality.org:
qx = np.array([3.3900e-03,2.0000e-04,9.0000e-05,1.1000e-04,1.8000e-04,1.1000e-04, 4.0000e-05,1.6000e-04,5.0000e-05,5.0000e-05,7.0000e-05,2.0000e-05,
9.0000e-05,5.0000e-05,9.0000e-05,3.4000e-04,1.4000e-04,1.9000e-04, 2.9000e-04,4.7000e-04,3.8000e-04,5.0000e-04,4.0000e-04,4.9000e-04,
4.5000e-04,4.9000e-04,3.7000e-04,4.6000e-04,3.3000e-04,4.1000e-04, 3.1000e-04,4.4000e-04,5.2000e-04,4.7000e-04,5.3000e-04,5.5000e-04,
3.9000e-04,5.6000e-04,6.3000e-04,7.2000e-04,9.1000e-04,8.6000e-04, 9.5000e-04,9.7000e-04,9.4000e-04,1.1100e-03,1.2500e-03,1.3200e-03,
1.4500e-03,1.9700e-03,2.2900e-03,1.9800e-03,2.3600e-03,2.6800e-03, 3.0800e-03,3.4700e-03,3.8900e-03,4.1700e-03,4.3100e-03,5.3200e-03,
5.6700e-03,6.6200e-03,7.1200e-03,7.3800e-03,8.6100e-03,9.5000e-03, 1.1140e-02,1.1490e-02,1.2670e-02,1.2730e-02,1.5300e-02,1.6960e-02,
1.9290e-02,2.1150e-02,2.2790e-02,2.4540e-02,2.6770e-02,3.0400e-02, 3.2880e-02,4.0410e-02,4.1910e-02,4.7380e-02,5.6770e-02,6.6290e-02,
7.2150e-02,8.2240e-02,9.6120e-02,1.0834e-01,1.2484e-01,1.3984e-01, 1.5434e-01,1.7482e-01,1.9335e-01,2.1427e-01,2.2078e-01,2.6159e-01,
2.8584e-01,3.1061e-01,3.3564e-01,3.6064e-01,3.8533e-01,4.0946e-01, 4.3277e-01,4.5506e-01,4.7616e-01,4.9595e-01,5.1433e-01,5.3128e-01, 5.4679e-01,5.6087e-01])
age_range_ext = np.arange(0, 85)
plt.plot(age_range_ext, np.log(qx[age_range_ext]))
plt.xlabel('Age')
plt.ylabel('log annual mortality rate')
plt.title('Swiss males in year 2019')
plt.show()
Calculating 10-year mortality rates $_{10}q_x$:
qx10 = 1 - np.prod(1-np.lib.stride_tricks.sliding_window_view(qx, 10), axis = 1)
plt.plot(age_range_ext, np.log(qx10[age_range_ext]))
plt.xlabel('Age')
plt.ylabel('log 10-year mortality rate')
plt.title('Swiss males in year 2019')
plt.show()
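The sliding-window product implements $_{10}q_x = 1 - \prod_{k=0}^{9}(1 - q_{x+k})$. A small 3-year example with toy rates (not HMD data) makes the mechanics explicit:

```python
import numpy as np

# Toy annual mortality rates (not HMD data)
qx_toy = np.array([0.1, 0.2, 0.3, 0.4])

# 3-year rates: one minus the product of 3 consecutive survival probabilities
q3 = 1 - np.prod(1 - np.lib.stride_tricks.sliding_window_view(qx_toy, 3), axis=1)

# First window: 1 - 0.9 * 0.8 * 0.7 = 0.496; second: 1 - 0.8 * 0.7 * 0.6 = 0.664
assert np.allclose(q3, [0.496, 0.664])
```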
We fit a polynomial of degree 14 to the difference between the log 10-year mortality and log QRISK3.
qrisk = qrisk3(age_range_ext, gender, height, weight, tcl_hdl_ratio, sbp, sd_sbp)
plt.plot(age_range_ext, np.log(qx10[age_range_ext]), 'g')
plt.plot(age_range_ext, np.log(qrisk[age_range_ext]), 'b')
plt.ylim(-7.5, 0)
polyfit = np.polyfit(age_range_ext, np.log(qx10[age_range_ext])-np.log(qrisk), 14)
residual = np.poly1d(polyfit)
plt.plot(age_range_ext, residual(age_range_ext) + np.log(qrisk), 'r')
plt.xlabel('Age')
plt.ylabel('log 10-year mortality/CVD risk')
plt.legend(['log 10-year mortality', 'log QRISK3', 'log QRISK3 + polynomial fit'])
plt.show()
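As a reminder of the numpy API used here, `np.polyfit` returns the least-squares coefficients from the highest degree down, and `np.poly1d` wraps them into a callable polynomial; a minimal sketch with exactly quadratic toy data:

```python
import numpy as np

x = np.linspace(0, 1, 50)
y = 1.0 + 2.0 * x - 3.0 * x ** 2        # exactly quadratic toy data

coeffs = np.polyfit(x, y, 2)            # least-squares fit of degree 2
poly = np.poly1d(coeffs)                # wrap coefficients into a callable

assert np.allclose(coeffs, [-3.0, 2.0, 1.0])   # highest degree first
assert np.isclose(poly(0.5), 1.0 + 2.0 * 0.5 - 3.0 * 0.25)
```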
Finally, we extend QRISK3 to mortality by defining:
def qrisk3_ext(age, gender, height, weight, tcl_hdl_ratio, sbp, sd_sbp, num1, num2, num3, binary):
q = np.exp(np.log(qrisk3(age, gender, height, weight, tcl_hdl_ratio, sbp, sd_sbp)) + residual(age)
+ 16*(num1-0.5)**4 + 4*(num2-0.5)**2*num3 + num3 + binary - 1.65)
return np.maximum(0.0001, np.minimum(q, 1))
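The last line caps the extended rate to the interval $[10^{-4}, 1]$; the same max/min sandwich can be written as a single `np.clip` call:

```python
import numpy as np

q = np.array([-0.5, 0.00005, 0.3, 1.7])
clamped = np.maximum(0.0001, np.minimum(q, 1))   # as in qrisk3_ext
assert np.allclose(clamped, np.clip(q, 0.0001, 1.0))
```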
We define several more ranges, including those for 3 abstract, continuous variables num1, num2, num3, e.g., step counts, triglycerides, resting heart rate, etc. Also, we introduce another binary variable called binary.
tcl_hdl_ratio_range = np.arange(2.5, 7.5, 0.5)
sd_sbp_range = np.arange(5, 15.5, 0.5)
num1_range = np.arange(0, 1.05, 0.05)
num2_range = np.arange(0, 1.05, 0.05)
num3_range = np.arange(0, 1.05, 0.05)
Let's simulate the second dataset $D_2$. Apart from the same dependency between BMI and SBP as in dataset $D_1$, we use another non-trivial dependency between binary and num3:
D2.BINARY = np.round((D2.NUM3 + 3*np.random.uniform(low=0, high=1, size=persons))/4)
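This construction makes BINARY depend on NUM3: rounding $(n + 3u)/4$ with $u\sim U(0,1)$ gives 1 exactly when $u \geq (2-n)/3$, so $P(\mathrm{BINARY}=1 \mid \mathrm{NUM3}=n) = (1+n)/3$. A quick Monte Carlo check for one fixed NUM3 value (a sketch, not part of the simulation):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 0.4                                    # a fixed NUM3 value
u = rng.uniform(0, 1, size=200_000)
binary = np.round((n + 3 * u) / 4)

# round((n + 3u)/4) = 1  iff  u >= (2 - n)/3, hence the conditional
# probability of BINARY = 1 given NUM3 = n equals (1 + n)/3
assert abs(binary.mean() - (1 + n) / 3) < 0.01
```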
persons = 2000000
D2 = pd.DataFrame(index = range(persons), columns = ['AGE', 'GENDER', 'BMI', 'SBP', 'TCL_HDL_RATIO', 'SD_SBP', 'NUM1', 'NUM2', 'NUM3', 'BINARY', 'mu', 'T', 'event'])
np.random.seed(0)
sbp_bmi = np.random.multivariate_normal([125, 3.2], np.matmul(np.matmul(np.diag([15, 0.25]), np.array([[1, 0.25], [0.25, 1]])), np.diag([15, 0.25])), size=persons)
D2.SBP = sbp_bmi[:,0]
D2.BMI = np.exp(sbp_bmi[:,1])
D2.AGE = np.random.uniform(low=0, high=84, size=persons).astype(int)
D2.GENDER = np.random.binomial(1, 0.5, size=persons)
D2.TCL_HDL_RATIO = np.random.uniform(low=min(tcl_hdl_ratio_range), high=max(tcl_hdl_ratio_range), size=persons)
D2.SD_SBP = np.random.uniform(low=min(sd_sbp_range), high=max(sd_sbp_range), size=persons)
D2.NUM1 = np.random.uniform(low=0, high=1, size=persons)
D2.NUM2 = np.random.uniform(low=0, high=1, size=persons)
D2.NUM3 = np.random.uniform(low=0, high=1, size=persons)
D2.BINARY = np.round((D2.NUM3 + 3*np.random.uniform(low=0, high=1, size=persons))/4)
q_male = qrisk3_ext(D2.AGE, "male", height, np.exp(sbp_bmi[:,1])*(height/100)**2, D2.TCL_HDL_RATIO, sbp_bmi[:,0], D2.SD_SBP, D2.NUM1, D2.NUM2, D2.NUM3, D2.BINARY)
D2.mu = qrisk3_ext(D2.AGE, "female", height, np.exp(sbp_bmi[:,1])*(height/100)**2, D2.TCL_HDL_RATIO, sbp_bmi[:,0], D2.SD_SBP, D2.NUM1, D2.NUM2, D2.NUM3, D2.BINARY)
D2.loc[D2.GENDER == 0, "mu"] = q_male[D2.GENDER == 0]
D2.event = np.where(np.random.uniform(0, 1, persons) < D2.mu, 1, 0)
D2["T"] = D2.apply(lambda row : max_T if row.event == 0 else draw_T(row.mu), axis=1)
D2 = D2[(D2.SBP>=min(sbp_range)) & (D2.SBP<=max(sbp_range)) & (D2.BMI>=min(bmi_range)) & (D2.BMI<=max(bmi_range))].reset_index()
print("Number of events: ", np.sum(D2.event))
print("Number of individuals: ", len(D2))
print("Number of person years: ", int(np.sum(D2["T"])))
# min-max scaling covariates
D2['BMI_scaled'] = (D2.BMI-min(bmi_range))/(max(bmi_range)-min(bmi_range))
D2['SBP_scaled'] = (D2.SBP-min(sbp_range))/(max(sbp_range)-min(sbp_range))
D2['SD_SBP_scaled'] = (D2.SD_SBP-min(sd_sbp_range))/(max(sd_sbp_range)-min(sd_sbp_range))
D2['TCL_HDL_RATIO_scaled'] = (D2.TCL_HDL_RATIO-min(tcl_hdl_ratio_range))/(max(tcl_hdl_ratio_range)-min(tcl_hdl_ratio_range))
D2_train, D2_test = train_test_split(D2, test_size = 0.15, random_state = 0)
D2_train.head()
plt.figure(figsize=(10,6))
plt.plot(D2.AGE + np.random.uniform(-0.3, 0.3, len(D2)), np.log(D2.mu), 'b.', markersize=0.1)
plt.plot(age_range_ext, residual(age_range_ext) + np.log(qrisk), 'r')
plt.xlabel('Age')
plt.ylabel('Simulated log 10-year mortality rates')
plt.show()
Number of events:  144773
Number of individuals:  1376395
Number of person years:  12959172
As a reference baseline for the ROC AUC, we sort $D_{2, \mathrm{test}}$ by the auxiliary score AGE - 6*GENDER.
age_gender = D2_test.copy()
age_gender["AGE_GENDER"] = age_gender.AGE - age_gender.GENDER*6
age_gender = age_gender.sort_values(by=["AGE_GENDER"]).reset_index()
print("ROC AUC of sorting by age and gender: ", roc_auc_score(age_gender.event, age_gender.index.tolist()))
ROC AUC of sorting by age and gender: 0.8781611144733502
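The sort-and-use-the-index trick works because the ROC AUC depends only on the ordering of the scores: any strictly increasing transformation, such as replacing AGE - 6*GENDER by its sort position, leaves the AUC unchanged. A numpy-only illustration via the Mann-Whitney pair-counting formulation (ignoring ties):

```python
import numpy as np

def auc_from_scores(y, score):
    # Mann-Whitney formulation: fraction of (positive, negative) pairs
    # in which the positive case receives the strictly higher score
    pos, neg = score[y == 1], score[y == 0]
    return np.mean(pos[:, None] > neg[None, :])

y = np.array([0, 0, 1, 1, 0, 1])
score = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9])
rank = np.argsort(np.argsort(score))     # a strictly increasing transform

assert auc_from_scores(y, score) == auc_from_scores(y, rank)
```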
lin_formula = "AGE + GENDER + SBP_scaled + BMI_scaled + I(BMI_scaled**2) + SD_SBP_scaled + TCL_HDL_RATIO_scaled + NUM1 + NUM2 + NUM3 + BINARY"
log_reg_D2 = sm.logit(formula='event ~ ' + lin_formula, data=D2_train).fit()
print(log_reg_D2.summary())
pred = log_reg_D2.predict(D2_test)
auc_log_reg_D2 = roc_auc_score(D2_test.event, pred)
mse_log_reg_D2 = np.mean(np.square(np.log(D2_test.mu) - np.log(pred)))
dev_log_reg_D2 = -2*np.sum(D2_test.event*logit(pred) + np.log(1-pred))
kld_log_reg_D2 = np.sum(D2_test.event*np.log(D2_test.event/pred))
print("ROC AUC (test): ", auc_log_reg_D2)
print("MSE wrt log ground truth: ", mse_log_reg_D2)
print("Logistic deviance: ", dev_log_reg_D2)
print("Kullback-Leibler divergence: ", kld_log_reg_D2)
prob_true, prob_pred = calibration_curve(D2_test.event, pred, n_bins=50, strategy='quantile')
plt.figure(figsize=(10,6))
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0, 1], [0, 1], 'k')
plt.title('Calibration plot logistic regression')
plt.xlabel('Predicted probability')
plt.ylabel('True probability in each bin')
plt.show()
Optimization terminated successfully.
Current function value: 0.207599
Iterations 9
Logit Regression Results
==============================================================================
Dep. Variable: event No. Observations: 1169935
Model: Logit Df Residuals: 1169923
Method: MLE Df Model: 11
Date: Wed, 20 Sep 2023 Pseudo R-squ.: 0.3830
Time: 17:35:57 Log-Likelihood: -2.4288e+05
converged: True LL-Null: -3.9362e+05
Covariance Type: nonrobust LLR p-value: 0.000
========================================================================================
coef std err z P>|z| [0.025 0.975]
----------------------------------------------------------------------------------------
Intercept -10.0956 0.030 -337.084 0.000 -10.154 -10.037
AGE 0.0983 0.000 338.391 0.000 0.098 0.099
GENDER -0.3770 0.007 -51.109 0.000 -0.391 -0.363
SBP_scaled 0.4083 0.015 28.120 0.000 0.380 0.437
BMI_scaled -0.4053 0.054 -7.439 0.000 -0.512 -0.298
I(BMI_scaled ** 2) 0.5112 0.060 8.544 0.000 0.394 0.628
SD_SBP_scaled 0.0860 0.013 6.763 0.000 0.061 0.111
TCL_HDL_RATIO_scaled 0.8605 0.013 67.024 0.000 0.835 0.886
NUM1 -0.0065 0.013 -0.516 0.606 -0.031 0.018
NUM2 0.0053 0.013 0.415 0.678 -0.020 0.030
NUM3 1.8316 0.013 136.103 0.000 1.805 1.858
BINARY 1.3514 0.008 169.867 0.000 1.336 1.367
========================================================================================
ROC AUC (test): 0.9054449129641041
MSE wrt log ground truth: 1.738145237669142
Logistic deviance: 85382.84694950932
Kullback-Leibler divergence: 27648.251502104344
plt.figure(figsize=(10,6))
kmf = ll.KaplanMeierFitter()
kmf.fit(D2_train['T'], event_observed=D2_train.event)
kmf.plot_survival_function()
plt.figure(figsize=(10,6))
cph_D2 = ll.CoxPHFitter()
cph_D2.fit(D2_train, 'T', event_col='event', formula=lin_formula)
cph_D2.print_summary()
cph_D2.plot()
survival = cph_D2.predict_survival_function(D2_test)
pred = 1 - np.squeeze(np.array(survival.iloc[-1,:]))
auc_cox_reg_D2 = roc_auc_score(D2_test.event, pred)
mse_cox_reg_D2 = np.mean(np.square(np.log(D2_test.mu) - np.log(pred)))
dev_cox_reg_D2 = -2*np.sum(D2_test.event*logit(pred) + np.log(1-pred))
kld_cox_reg_D2 = np.sum(D2_test.event*np.log(D2_test.event/pred))
print("ROC AUC (test): ", auc_cox_reg_D2)
print("MSE wrt log ground truth: ", mse_cox_reg_D2)
print("Logistic deviance: ", dev_cox_reg_D2)
print("Kullback-Leibler divergence: ", kld_cox_reg_D2)
prob_true, prob_pred = calibration_curve(D2_test.event, pred, n_bins=50, strategy='quantile')
plt.figure(figsize=(10,6))
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0, 1], [0, 1], 'k')
plt.title('Calibration plot Cox regression')
plt.xlabel('Predicted probability')
plt.ylabel('True probability in each bin')
plt.show()
| model | lifelines.CoxPHFitter |
|---|---|
| duration col | 'T' |
| event col | 'event' |
| baseline estimation | breslow |
| number of observations | 1.16994e+06 |
| number of events observed | 123123 |
| partial log-likelihood | -1548453.70 |
| time fit was run | 2023-09-20 17:35:58 UTC |
| coef | exp(coef) | se(coef) | coef lower 95% | coef upper 95% | exp(coef) lower 95% | exp(coef) upper 95% | cmp to | z | p | -log2(p) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| AGE | 0.10 | 1.10 | 0.00 | 0.10 | 0.10 | 1.10 | 1.10 | 0.00 | 373.91 | <0.005 | inf |
| GENDER | -0.37 | 0.69 | 0.01 | -0.38 | -0.36 | 0.68 | 0.70 | 0.00 | -64.63 | <0.005 | inf |
| SBP_scaled | 0.35 | 1.42 | 0.01 | 0.33 | 0.37 | 1.39 | 1.45 | 0.00 | 31.02 | <0.005 | 699.23 |
| BMI_scaled | -0.42 | 0.66 | 0.04 | -0.50 | -0.34 | 0.61 | 0.71 | 0.00 | -9.94 | <0.005 | 74.89 |
| I(BMI_scaled**2) | 0.46 | 1.58 | 0.05 | 0.37 | 0.55 | 1.44 | 1.73 | 0.00 | 9.90 | <0.005 | 74.30 |
| SD_SBP_scaled | 0.09 | 1.09 | 0.01 | 0.07 | 0.11 | 1.07 | 1.11 | 0.00 | 9.00 | <0.005 | 61.98 |
| TCL_HDL_RATIO_scaled | 0.83 | 2.30 | 0.01 | 0.81 | 0.85 | 2.25 | 2.34 | 0.00 | 83.20 | <0.005 | inf |
| NUM1 | -0.01 | 0.99 | 0.01 | -0.03 | 0.01 | 0.97 | 1.01 | 0.00 | -1.36 | 0.17 | 2.53 |
| NUM2 | 0.00 | 1.00 | 0.01 | -0.02 | 0.02 | 0.98 | 1.02 | 0.00 | 0.09 | 0.93 | 0.10 |
| NUM3 | 1.79 | 6.00 | 0.01 | 1.77 | 1.81 | 5.87 | 6.12 | 0.00 | 167.09 | <0.005 | inf |
| BINARY | 1.28 | 3.60 | 0.01 | 1.27 | 1.29 | 3.55 | 3.65 | 0.00 | 194.75 | <0.005 | inf |
| Concordance | 0.89 |
|---|---|
| Partial AIC | 3096929.41 |
| log-likelihood ratio test | 330317.33 on 11 df |
| -log2(p) of ll-ratio test | inf |
ROC AUC (test):  0.9055297903483172
MSE wrt log ground truth:  1.7546013426946472
Logistic deviance:  83994.21928553453
Kullback-Leibler divergence:  26474.98484694554
xvars = ['AGE', 'GENDER', 'SBP_scaled', 'BMI_scaled', 'TCL_HDL_RATIO_scaled', 'SD_SBP_scaled', 'NUM1', 'NUM2', 'NUM3', 'BINARY']
X_train = D2_train[xvars]
y_train = D2_train.event
neural_net_D2 = Sequential([
Input(shape=(10,)),
Dense(256, kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0)),
Activation(tf.keras.activations.relu),
Dense(128, kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0)),
Activation(tf.keras.activations.relu),
Dense(64, kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0)),
Activation(tf.keras.activations.relu),
Dense(1, activation = "sigmoid", kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0))
])
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(0.001, decay_steps=1000, decay_rate=0.99)
opt = tf.keras.optimizers.Adam(learning_rate = lr_schedule)
neural_net_D2.compile(optimizer = opt, loss = 'binary_crossentropy')
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience = 30, restore_best_weights=True)
history = neural_net_D2.fit(
X_train,
y_train,
batch_size=64,
epochs=100,
shuffle=True,
validation_split=0.20,
verbose=1,
callbacks=[callback]
)
plt.figure(figsize=(10,6))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()
print(neural_net_D2.summary())
X_test = D2_test[xvars]
y_test = D2_test.event
pred = neural_net_D2.predict(X_test).flatten()
auc_neural_net_D2 = roc_auc_score(D2_test.event, pred)
mse_neural_net_D2 = np.mean(np.square(np.log(D2_test.mu) - np.log(pred)))
dev_neural_net_D2 = -2*np.sum(D2_test.event*logit(pred) + np.log(1-pred))
kld_neural_net_D2 = np.sum(D2_test.event*np.log(D2_test.event/pred))
print("ROC AUC (test): ", auc_neural_net_D2)
print("MSE wrt log ground truth: ", mse_neural_net_D2)
print("Logistic deviance: ", dev_neural_net_D2)
print("Kullback-Leibler divergence: ", kld_neural_net_D2)
prob_true, prob_pred = calibration_curve(D2_test.event, pred, n_bins=50, strategy='quantile')
plt.figure(figsize=(10,6))
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0, 1], [0, 1], 'k')
plt.title('Calibration plot neural network')
plt.xlabel('Predicted probability')
plt.ylabel('True probability in each bin')
plt.show()
Epoch 1/100   14625/14625 [==============================] - 55s 4ms/step - loss: 0.2126 - val_loss: 0.1959
Epoch 2/100   14625/14625 [==============================] - 51s 4ms/step - loss: 0.1990 - val_loss: 0.1985
...
Epoch 99/100  14625/14625 [==============================] - 57s 4ms/step - loss: 0.1838 - val_loss: 0.1827
Epoch 100/100 14625/14625 [==============================] - 54s 4ms/step - loss: 0.1838 - val_loss: 0.1827
(epochs 3-98 truncated; the training loss plateaus at 0.1838 and the validation loss at 0.1827 from about epoch 40 onward)
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_4 (Dense) (None, 256) 2816
activation_2 (Activation) (None, 256) 0
dense_5 (Dense) (None, 128) 32896
activation_3 (Activation) (None, 128) 0
dense_6 (Dense) (None, 64) 8256
activation_4 (Activation) (None, 64) 0
dense_7 (Dense) (None, 1) 65
=================================================================
Total params: 44033 (172.00 KB)
Trainable params: 44033 (172.00 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None
6452/6452 [==============================] - 10s 2ms/step
ROC AUC (test): 0.9205069049756238
MSE wrt log ground truth: 0.10821129939518413
Logistic deviance: 75732.18857970255
Kullback-Leibler divergence: 24621.50699719608
Overview of model performance on $D_{2, \mathrm{test}}$.
print("Logistic regression")
print("===================")
print("ROC AUC (test): ", auc_log_reg_D2)
print("MSE wrt log ground truth: ", mse_log_reg_D2)
print("Logistic deviance: ", dev_log_reg_D2)
print("Kullback-Leibler divergence: ", kld_log_reg_D2)
print("\nCox regression")
print("==============")
print("ROC AUC (test): ", auc_cox_reg_D2)
print("MSE wrt log ground truth: ", mse_cox_reg_D2)
print("Logistic deviance: ", dev_cox_reg_D2)
print("Kullback-Leibler divergence: ", kld_cox_reg_D2)
print("\nNeural net")
print("==================")
print("ROC AUC (test): ", auc_neural_net_D2)
print("MSE wrt log ground truth: ", mse_neural_net_D2)
print("Logistic deviance: ", dev_neural_net_D2)
print("Kullback-Leibler divergence: ", kld_neural_net_D2)
Logistic regression
===================
ROC AUC (test):  0.9054449129641041
MSE wrt log ground truth:  1.738145237669142
Logistic deviance:  85382.84694950932
Kullback-Leibler divergence:  27648.251502104344

Cox regression
==============
ROC AUC (test):  0.9055297903483172
MSE wrt log ground truth:  1.7546013426946472
Logistic deviance:  83994.21928553453
Kullback-Leibler divergence:  26474.98484694554

Neural net
==================
ROC AUC (test):  0.9205069049756238
MSE wrt log ground truth:  0.10821129939518413
Logistic deviance:  75732.18857970255
Kullback-Leibler divergence:  24621.50699719608
From classical statistical modeling, we are accustomed to studying the "interior" of a model by looking at estimated model coefficients:
XAI (eXplainable AI) is a collection of methods to study such aspects for any model, e.g., a neural network or a logistic regression.
Important even when no one asks for them! (Why?)
Examples?
Aspects 2 and 3 via:
For simplicity, we ignore the Cox model here.
# Show train and test AUC for log reg and nn
print("Logistic regression")
print("-Test AUC", roc_auc_score(y_test, log_reg_D2.predict(X_test)))
print("-Train AUC", roc_auc_score(y_train, log_reg_D2.predict(X_train)))
print("\nNeural net")
print("-Test AUC", roc_auc_score(y_test, neural_net_D2.predict(X_test, batch_size=100_000, verbose=False)))
print("-Train AUC", roc_auc_score(y_train, neural_net_D2.predict(X_train, batch_size=100_000, verbose=False)))
Logistic regression
-Test AUC 0.9054449129641041
-Train AUC 0.9047912732610746

Neural net
-Test AUC 0.9205069059753399
-Train AUC 0.9205863326848361
Comments:
Besides model-specific measures of variable importance (?), model-agnostic measures exist as well, in particular:
Permutation Variable Importance of $j$-th feature and data $D$:
Variants?
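Permutation importance, as computed by dalex's `model_parts` in the cells below, can be sketched in a few lines: shuffle one feature, remeasure the loss, and report the difference. A minimal illustration on a synthetic dataset (not the notebook's data; all names and parameters here are illustrative only):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
X, y = make_classification(n_samples=2000, n_features=4, n_informative=2, random_state=0)
model = LogisticRegression().fit(X, y)
base_auc = roc_auc_score(y, model.predict_proba(X)[:, 1])

importance = {}
for j in range(X.shape[1]):
    Xp = X.copy()
    Xp[:, j] = rng.permutation(Xp[:, j])  # break the link between feature j and y
    importance[j] = base_auc - roc_auc_score(y, model.predict_proba(Xp)[:, 1])
```

Informative features show a clear AUC drop, noise features stay near zero; with `B > 1`, dalex averages several permutations to reduce the Monte Carlo error of the shuffle.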
log_reg_explainer = dx.Explainer(
log_reg_D2,
data=X_test,
y=y_test,
predict_function=lambda m, X: np.array(m.predict(X)),
verbose=False,
)
log_reg_explainer.model_parts(
loss_function='1-auc', N=10000, B=1, type="difference"
).plot(title="Logistic regression")
def neural_net_predict(m, X):
return np.squeeze(np.array(m.predict(X, batch_size=100_000, verbose=False)))
neural_net_explainer = dx.Explainer(
neural_net_D2,
data=X_test,
y=y_test,
predict_function=neural_net_predict,
verbose=False,
)
neural_net_explainer.model_parts(
loss_function='1-auc', N = 10000, B=1, type="difference"
).plot(title="Neural net")
Comments:
In logistic regression (without interactions), we can read off the effect of $j$-th feature from the fitted coefficients $\hat \beta_j$:
"A 1-unit increase in $x_j$ is associated with a change in the log odds by $\hat \beta_j$, keeping everything else fixed (Ceteris Paribus)."
For a black-box, due to complex interactions, this effect varies from observation to observation.
What can we say about (Ceteris Paribus) effect of $j$-th feature?
Working on "link" space (here logit) usually better (why?)
Ceteris Paribus is curse and blessing (why?)
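A ceteris paribus profile can also be computed by hand: copy one observation, vary a single feature over a grid while keeping all others fixed, and inspect the prediction on the link scale. A minimal sketch on synthetic data (not the notebook's data); for logistic regression the log-odds profile is exactly linear in the varied feature, with slope equal to the fitted coefficient:

```python
import numpy as np
from scipy.special import logit
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1000, n_features=4, random_state=1)
model = LogisticRegression().fit(X, y)

x0 = X[0].copy()                                  # one observation
grid = np.linspace(X[:, 0].min(), X[:, 0].max(), 25)
X_cp = np.tile(x0, (len(grid), 1))
X_cp[:, 0] = grid                                 # vary feature 0, hold the rest fixed
profile = logit(model.predict_proba(X_cp)[:, 1])  # profile on the log-odds scale
slope = np.polyfit(grid, profile, 1)[0]           # recovers the fitted coefficient
```

For a black-box model such as the neural net, the same construction generally yields a curved, observation-dependent profile, which is what the dalex `model_profile` plots below visualize.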
vars = ["AGE", "GENDER"]
def logit(x):
return np.log(x / (1 - x))
def neural_net_predict_logit(m, X):
return logit(neural_net_predict(m, X))
log_reg_explainer_logit = dx.Explainer(
log_reg_D2,
data=X_test,
y=y_test,
predict_function=lambda m, X: logit(np.array(m.predict(X))),
verbose=False,
)
neural_net_explainer_logit = dx.Explainer(
neural_net_D2,
data=X_test,
y=y_test,
predict_function=neural_net_predict_logit,
verbose=False,
)
log_reg_explainer_logit.model_profile(
variables=vars, N=100, grid_points=49, verbose=False,
).plot(geom="profiles", title="Logistic regression (log-odds scale)")
neural_net_explainer_logit.model_profile(
variables=vars, N=100, grid_points=49, verbose=False,
).plot(geom="profiles", title="Neural net (log-odds scale)")
Comments:
Same on probability scale
log_reg_explainer.model_profile(
variables=vars, N=100, grid_points=49, verbose=False,
).plot(geom="profiles", title="Logistic regression (probability scale)")
neural_net_explainer.model_profile(
variables=vars, N=100, grid_points=49, verbose=False,
).plot(geom="profiles", title="Neural net (probability scale)")
Applying the ideas discussed in M. Lindholm, R. Richman, A. Tsanakas, and M.V. Wüthrich, "Discrimination-free insurance pricing", we consider four subsets of individuals from X_test for the best-estimate and non-discriminatory prices, (NUM3 < 0.5, BINARY = 0), (NUM3 >= 0.5, BINARY = 0), (NUM3 < 0.5, BINARY = 1), (NUM3 >= 0.5, BINARY = 1), and the two subsets NUM3 < 0.5 and NUM3 >= 0.5 for the unawareness prices. Note the difference between the unawareness prices and the discrimination-free prices: the higher second unawareness price (0.1418) is considered discriminatory. Privacy-preserving methods can thus also help to mitigate this type of discrimination by unawareness.
X_test_00 = X_test.loc[(X_test.NUM3 < 0.5) & (X_test.BINARY == 0)]
X_test_01 = X_test.loc[(X_test.NUM3 < 0.5) & (X_test.BINARY == 1)]
X_test_10 = X_test.loc[(X_test.NUM3 >= 0.5) & (X_test.BINARY == 0)]
X_test_11 = X_test.loc[(X_test.NUM3 >= 0.5) & (X_test.BINARY == 1)]
N_00 = len(X_test_00)
N_01 = len(X_test_01)
N_10 = len(X_test_10)
N_11 = len(X_test_11)
P_00 = np.mean(y_test.loc[(X_test.NUM3 < 0.5) & (X_test.BINARY == 0)])
P_01 = np.mean(y_test.loc[(X_test.NUM3 < 0.5) & (X_test.BINARY == 1)])
P_10 = np.mean(y_test.loc[(X_test.NUM3 >= 0.5) & (X_test.BINARY == 0)])
P_11 = np.mean(y_test.loc[(X_test.NUM3 >= 0.5) & (X_test.BINARY == 1)])
P_0 = np.mean(y_test.loc[(X_test.NUM3 < 0.5)])
P_1 = np.mean(y_test.loc[(X_test.NUM3 >= 0.5)])
print("Exposures: ", [N_00, N_01, N_10, N_11])
print("Best estimate prices: ", [P_00, P_01, P_10, P_11])
print("Unawareness prices for NUM3 < 0.5 and NUM3 >= 0.5: ", [P_0, P_1])
print("Non-discriminatory prices for NUM3 < 0.5 and NUM3 >= 0.5: ", [(P_00*(N_00+N_10)+P_01*(N_01+N_11))/(N_00+N_01+N_10+N_11), (P_10*(N_00+N_10)+P_11*(N_01+N_11))/(N_00+N_01+N_10+N_11)])
Exposures:  [60371, 42947, 42760, 60382]
Best estimate prices:  [0.03929038776896192, 0.10827298763592334, 0.08126753975678204, 0.18470736312146005]
Unawareness prices for NUM3 < 0.5 and NUM3 >= 0.5:  [0.06796492382740665, 0.1418238932733513]
Non-discriminatory prices for NUM3 < 0.5 and NUM3 >= 0.5:  [0.07381476567099261, 0.13303705205189398]
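In formulas, with best-estimate price $\mu(x, b) = \Pr(Y = 1 \mid X = x, B = b)$ for non-protected information $X$ (here the NUM3 group) and protected attribute $B$ (here BINARY), the two aggregated prices computed above are

```latex
h^{\mathrm{unaw}}(x) = \sum_{b} \mu(x, b)\,\Pr(B = b \mid X = x),
\qquad
h^{\mathrm{df}}(x) = \sum_{b} \mu(x, b)\,\Pr(B = b).
```

The unawareness price weights by the conditional distribution of $B$ given $X$, so the protected attribute is still inferred implicitly from $X$; the discrimination-free price uses the marginal distribution of $B$, which is exactly the exposure-weighted average in the code above.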
In this last part we showcase the concept of homomorphic encryption, first introducing several helper functions (extended Euclidean algorithm, Jacobi symbol, probabilistic primality testing, and text/integer conversion) used by the cryptographic classes below.
import random, array, time
import numpy as np
from math import exp, log, sqrt
def gcd_ext(a, b):
# returns gcd, x, y, such that gcd = a*x + b*y
if a == 0: return b, 0, 1
d, x, y = gcd_ext(b % a, a)
return d, y - (b // a) * x, x
def jacobi(a, n):
# Jacobi symbol (generalized Legendre symbol) for positive, odd n
a = a % n
t = 1
while a != 0:
while a % 2 == 0:
a = a // 2
r = n % 8
if r == 3 or r == 5: t = -t
a, n = n, a
if a % 4 == 3 and n % 4 == 3: t = -t
a = a % n
if n == 1: return t
else: return 0
def is_probable_prime(n, k = 20):
# Solovay-Strassen probabilistic prime number test
if n == 2: return True
if n % 2 == 0 or n == 1: return False
for j in range(k):
r = random.randint(2, n-1)
x = jacobi(r, n) % n
if x == 0 or pow(r, (n-1)//2, n) != x:
return False
return True
def next_probable_prime(n):
if n % 2 == 0: n += 1
while not is_probable_prime(n): n += 2
return n
def text_to_int(text):
return int.from_bytes(text.encode(), "little")
def int_to_text(n):
try:
out = n.to_bytes(int(log(n)/log(256)+1), "little").decode()
except Exception:
out = "Invalid integer"
return out
def rsa_attack(n, e):
# retrieve private key from public key by factoring n
factors = MPQS(n).factor()
p = factors[0]
q = factors[1]
d = gcd_ext(e, (p-1)*(q-1))[1] % ((p-1)*(q-1))
return n, d
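The attack in rsa_attack works because knowing the factors p and q of n reveals $\varphi(n) = (p-1)(q-1)$, from which the private exponent $d = e^{-1} \bmod \varphi(n)$ follows immediately. A tiny sanity check with textbook-sized toy primes (Python 3.8+ for the modular inverse via `pow(e, -1, m)`):

```python
p, q, e = 1009, 1013, 17          # toy parameters, far too small for real use
n, phi = p * q, (p - 1) * (q - 1)
d = pow(e, -1, phi)               # private exponent: modular inverse of e mod phi(n)
m = 424242                        # message as an integer < n
c = pow(m, e, n)                  # encrypt with the public key (n, e)
assert pow(c, d, n) == m          # decrypt with the private key (n, d) recovers m
```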
The next code block is an implementation of the (multiple polynomial) quadratic sieve for factoring integers; see, e.g., https://www.ams.org/notices/199612/pomerance.pdf
class MPQS:
def __init__(self, n, max_trial_division=1e6, sieve_window_size=100000, multiplier=10):
# n, number to be factored
# max_trial_division, first perform trial division by all primes up to max_trial_division
# sieve_window_size and multiplier are tuning parameters that impact computation times
if n <= 1: raise ValueError("Invalid n")
print("Trying to factor {}-digit number {}".format(len(str(n)), n))
self.sieve_window_size = sieve_window_size
self.max_trial_division = int(max_trial_division)
self.multiplier = multiplier
_, self.n, self.power = self.is_power(n, max_trial_division)
self.factors, self.n = self.trial_division(self.n, self.max_trial_division)
self.sqrt_n = sqrt(self.n)
self.b = int(exp(sqrt(log(self.n)*log(log(self.n)+1e-16))/2))
self.factor_base = [-1, 2] + [p for p in self.first_primes(self.b)[1:] if jacobi(self.n, p) == 1]
self.sqrt_p, self.log_p = self.get_roots_logs(self.factor_base, self.n)
self.exponent_vectors = np.zeros((len(self.factor_base), len(self.factor_base)), dtype=bool)
self.X_list = [0] * len(self.factor_base)
self.Y_list = [0] * len(self.factor_base)
self.row = [-1] * len(self.factor_base)
self.numbers_found = 0
self.M = self.multiplier * len(self.factor_base)
self.A = 0
self.B = 0
self.C = 0
self.D = int(sqrt(sqrt(2*self.n)/self.M))
self.D_inv = 0
self.root1 = [0] * len(self.factor_base)
self.root2 = [0] * len(self.factor_base)
def factor(self):
print("MPQS trying to factor {}-digit number {}".format(len(str(self.n)), self.n))
if is_probable_prime(self.n): return np.repeat(self.factors + [self.n], self.power).tolist()
if self.n == 1: return np.repeat(self.factors, self.power).tolist()
factors = self.factor_into_two()
out = []
for p in factors:
if is_probable_prime(p): out.append(p)
else: out += MPQS(p, self.max_trial_division, self.sieve_window_size, self.multiplier).factor()
print("\n{}".format(np.repeat(self.factors + out, self.power).tolist()))
return np.repeat(self.factors + out, self.power).tolist()
def trial_division(self, n, max_trial_division):
out = []
for p in self.first_primes(max_trial_division):
while n % p == 0:
n //= p
out.append(p)
return out, n
def is_power(self, n, min_base):
for r in range(2, int(log(n)/log(min_base) + 2)):
p = self.int_root(n, r)
if p != -1: return True, p, r
return False, n, 1
def int_root(_, n, r):
low = int(n**(1/(r+0.1)))
high = int(n**(1/(r-0.1)))
while high - low > 2:
center = high//2 + low//2
if center**r > n: high = center
else: low = center
if high**r == n: return high
if (high-1)**r == n: return high - 1
if low**r == n: return low
return -1
def first_primes(_, n):
is_prime = np.ones(n+1, dtype=bool)
out = []
i = 2
while 1:
while i <= n and not is_prime[i]: i = i + 1
if i == n+1: return out
out.append(i)
idx = np.arange(2*i, n + 1, i)
is_prime[idx] = False
i += 1
def get_roots_logs(self, factor_base, n):
sqrt_p = [0] * len(factor_base)
log_p = [0] * len(factor_base)
for i in range(1, len(factor_base)):
sqrt_p[i] = self.sqrt_mod_p(n, factor_base[i])
log_p[i] = log(factor_base[i])
return sqrt_p, log_p
def sqrt_mod_p(self, n, p):
if p == 2: return n % 2
if p % 4 == 3: return pow(n, (p + 1) // 4, p)
if p % 8 == 5:
out = pow(n, (p + 3) // 8, p)
if pow(out, 2, p) != n % p: out = out * pow(2, (p-1)//4, p) % p
return out % p
s = 0
t = p - 1
while (t % 2 == 0):
t = t // 2
s += 1
d = 2
while jacobi(d, p) == 1: d += 1
nt = pow(n, t, p)
dt = pow(d, t, p)
m = 0
for i in range(0, s-1):
T = pow(dt, m, p) * nt
T = pow(T, pow(2, s - 2 - i), p) + 1
if T == p: m = m + pow(2, i+1)
return (pow(n, (t+1)//2, p)*pow(dt, m//2, p)) % p
def set_next_polynomial(self):
self.D = self.D + (3 - self.D % 4) + 4
while not is_probable_prime(self.D) or pow(self.n, (self.D-1)//2, self.D) != 1: self.D += 4
self.A = pow(self.D, 2, self.n)
H1 = pow(self.n, (self.D + 1)//4, self.D)
t = gcd_ext(2*H1, self.D)[1] # inverse of 2*H1 mod D
H2 = t*((self.n-H1*H1)//self.D) % self.D
self.B = (H1 + H2*self.D) % self.A
self.C = (self.B*self.B - self.n)//self.A
self.D_inv = gcd_ext(self.D, self.n)[1] % self.n
self.root1 = [0] * (len(self.factor_base))
self.root2 = [0] * (len(self.factor_base))
for i in range(1, len(self.factor_base)):
B_p = self.B % self.factor_base[i]
A_inv_p = gcd_ext(self.A, self.factor_base[i])[1] % self.factor_base[i]
self.root1[i] = ((self.sqrt_p[i] - B_p) * A_inv_p ) % self.factor_base[i]
self.root2[i] = ((-self.sqrt_p[i] - B_p) * A_inv_p ) % self.factor_base[i]
def factor_into_two(self):
while 1:
self.set_next_polynomial()
factors = self.sieve()
if len(factors) > 1:
return factors
def sieve(self):
for window in range(-self.M//self.sieve_window_size, self.M//self.sieve_window_size):
sieve = np.zeros(self.sieve_window_size)
sieve_start = window * self.sieve_window_size
sieve_end = (window+1) * self.sieve_window_size - 0.5
qx_start = (self.A*sieve_start+2*self.B)*sieve_start+self.C
qx_end = (self.A*sieve_end+2*self.B)*sieve_end+self.C
threshold = log(min(abs(qx_start), abs(qx_end))) - 1
for f in range(1, len(self.factor_base)):
r = sieve_start % self.factor_base[f]
r1 = (self.root1[f] - r) % self.factor_base[f]
r2 = (self.root2[f] - r) % self.factor_base[f]
idx = np.concatenate((np.arange(r1, self.sieve_window_size, self.factor_base[f]), np.arange(r2, self.sieve_window_size, self.factor_base[f])))
sieve[idx] += self.log_p[f]
for s in np.where(sieve > threshold)[0]:
exponent_vector = np.zeros(len(self.factor_base), dtype=int)
x = sieve_start + int(s)
X = self.A*x + self.B
qx = (X + self.B)*x + self.C
X = X*self.D_inv % self.n
if qx < 0:
qx = self.n - qx
exponent_vector[0] = 1
for f in range(1, len(self.factor_base)):
r = x % self.factor_base[f]
if (r != self.root1[f] and r != self.root2[f]): continue
while qx >= self.factor_base[f] and qx % self.factor_base[f] == 0:
qx = qx // self.factor_base[f]
exponent_vector[f] += 1
if qx == 1: # b-smooth Q(x) has been found
Y = 1
idx = np.where(exponent_vector[1:]>=2)[0]+1
for f in idx:
Y = (Y * pow(self.factor_base[f], int(exponent_vector[f]) // 2)) % self.n
exponent_vector[idx] %= 2
g = gcd_ext(X + Y, self.n)[0] % self.n
if g > 1 and g < self.n-1: return [g, self.n//g]
exponent_vector_root = np.zeros(len(self.factor_base), dtype=int)
for f in np.arange(len(self.factor_base)):
if exponent_vector[f] % 2 == 0: continue
if self.row[f] >= 0:
X = X*self.X_list[self.row[f]] % self.n
Y = Y*self.Y_list[self.row[f]] % self.n
exponent_vector += self.exponent_vectors[self.row[f], :]
idx = np.where(exponent_vector>=2)[0]
exponent_vector_root[idx] += exponent_vector[idx] // 2
exponent_vector[idx] %= 2
if not np.any(exponent_vector[f:]):
for i in np.where(exponent_vector_root[1:]>0)[0]+1:
Y *= pow(self.factor_base[i], int(exponent_vector_root[i])) % self.n
exponent_vector_root = np.zeros(len(self.factor_base), dtype=int)
g = gcd_ext(X + Y, self.n)[0] % self.n
if g > 1 and g < self.n-1: return [g, self.n//g]
else:
for i in np.where(exponent_vector_root[1:]>0)[0]+1:
Y *= pow(self.factor_base[i], int(exponent_vector_root[i])) % self.n
self.row[f] = self.numbers_found
self.X_list[self.numbers_found] = X
self.Y_list[self.numbers_found] = Y
self.exponent_vectors[self.numbers_found, :] = exponent_vector % 2 == 1
self.numbers_found += 1
print("\rNumber of b-smooth numbers found ({} needed at most): {}".format(len(self.factor_base), self.numbers_found), end="")
break
return [self.n]
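The core idea the sieve exploits is the difference-of-squares congruence: once we find $x, y$ with $x^2 \equiv y^2 \pmod{n}$ and $x \not\equiv \pm y$, $\gcd(x - y, n)$ is a proper factor of $n$. Fermat's method, sketched below on a toy number, is the simplest instance of this idea; the quadratic sieve finds such pairs far faster by combining many $b$-smooth values:

```python
from math import gcd, isqrt

n = 8051                                   # toy composite, 83 * 97
x = isqrt(n) + 1
while isqrt(x * x - n) ** 2 != x * x - n:  # search until x^2 - n is a perfect square
    x += 1
y = isqrt(x * x - n)                       # now x^2 - y^2 = n
p = gcd(x - y, n)
print(p, n // p)                           # → 83 97
```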
The following classes implement asymmetric encryption with RSA and ElGamal, and symmetric encryption based on BBS (Blum-Blum-Shub, see https://shub.ccny.cuny.edu/articles/1986-A_simple_unpredictable_pseudo-random_number_generator.pdf).
class RSA:
def __init__(self, p=0, q=0, e=0):
size = 100000000000000000
if p==0: p = next_probable_prime(random.randint(1*size, 2*size))
if q==0: q = next_probable_prime(random.randint(3*size, 4*size))
if e==0:
self.e = next_probable_prime(random.randint(0, (p-1)*(q-1)))
while gcd_ext(self.e, (p-1)*(q-1))[0] != 1: self.e = (self.e + 1) % ((p-1)*(q-1))
else: self.e = e
self.n = p*q
if gcd_ext(self.e, (p-1)*(q-1))[0] != 1: raise ValueError("e and (p-1)*(q-1) need to be coprime")
self._d = gcd_ext(self.e, (p-1)*(q-1))[1] % ((p-1)*(q-1))
def get_public_key(self):
return self.n, self.e
def get_private_key(self):
return self.n, self._d
def encrypt(self, m, n=0, e=0):
if n==0: n = self.n
if e==0: e = self.e
if m>n: raise ValueError("m too large, n too small")
return pow(m, e, n)
def decrypt(self, c, n=0, d=0):
if n==0: n = self.n
if d==0: d = self._d
return pow(c, d, n)
class ElGamal:
def __init__(self, p=0, g=0):
size = 100000000000000000
if p==0:
self.p = 1
while not is_probable_prime(self.p):
p1, p2 = next_probable_prime(random.randint(1*size, 2*size)), next_probable_prime(random.randint(3*size, 4*size))
self.p = 2*p1*p2 + 1
self.g = random.randint(2, self.p-2)
while pow(self.g, 2*p1, self.p) == 1 or pow(self.g, 2*p2, self.p) == 1 or pow(self.g, p1*p2, self.p) == 1:
self.g = random.randint(2, self.p-2)
else:
self.p = p
self.g = g
self._x = random.randint(1, self.p-2)
self.h = pow(self.g, self._x, self.p)
self.nonce_list = []
def get_public_key(self):
return self.p, self.g, self.h
def get_private_key(self):
return self._x
def encrypt(self, m, p=0, g=0, h=0):
if p==0: p = self.p
if g==0: g = self.g
if h==0: h = self.h
nonce = random.randint(1, p-2)
while nonce in self.nonce_list:
nonce = random.randint(1, p-2)
self.nonce_list.append(nonce)
return m*pow(h, nonce, p) % p, pow(g, nonce, p)
def decrypt(self, c, d):
return c*gcd_ext(pow(d, self._x, self.p), self.p)[1] % self.p
class BBS:
def __init__(self, n=0, seed=0):
# create a pseudo-random sequence of period length 4*p1*p2*q1*q2, Blum-Blum-Shub
size = 10000000
self.skip = 1000
if n==0:
p, q = 0, 0
while not(is_probable_prime(p) and is_probable_prime(q)):
p1, p2 = random.randint(1*size, 2*size), random.randint(3*size, 4*size)
q1, q2 = random.randint(5*size, 6*size), random.randint(7*size, 8*size)
p, q = 2*p1*p2+1, 2*q1*q2+1
n = p*q
self.n = n
self.number_bytes = int(log(self.n)/log(256)+1)
self.nonce_list = []
if seed == 0: self._seed = random.randint(2,self.n)
else: self._seed = seed
def next(self, length, nonce):
current = (nonce + self._seed) % self.n
out = bytearray(length)
for i in range(self.skip):
current = pow(current, 2, self.n)
for i in range(length * 8):
current = pow(current, 2, self.n)
out[i//8] += (sum(current.to_bytes(self.number_bytes, "little")) % 2) << (i%8)
return out
def encrypt(self, message, nonce):
if nonce in self.nonce_list: raise ValueError("nonce was already used")
self.nonce_list.append(nonce)
bytes = message.encode()
return bytearray([x^y for x, y in zip(bytes, self.next(len(bytes), nonce))])
def decrypt(self, encrypted_message, nonce):
return bytearray([x^y for x, y in zip(encrypted_message, self.next(len(encrypted_message), nonce))]).decode()
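The BBS class above is a stream cipher: a keystream built from the low-order bits of iterated squares modulo a Blum integer is XOR-ed with the message, and XOR-ing again with the same keystream decrypts. A toy version with deliberately tiny (insecure) parameters makes the round trip explicit:

```python
p, q = 11, 23                  # both ≡ 3 (mod 4), so n = p*q is a Blum integer
n = p * q

def bbs_bits(seed, k):
    x, bits = seed % n, []
    for _ in range(k):
        x = pow(x, 2, n)       # iterated squaring modulo n
        bits.append(x & 1)     # keep the least-significant bit of each square
    return bits

def xor_stream(data, seed):
    ks = bbs_bits(seed, len(data) * 8)
    return bytes(b ^ sum(ks[8 * i + j] << j for j in range(8))
                 for i, b in enumerate(data))

ct = xor_stream(b"secret", seed=3)
assert xor_stream(ct, seed=3) == b"secret"   # XOR is its own inverse
```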
class Person:
def __init__(self, name, secret_message):
self.name = name
self._secret_message = secret_message
self._rsa = RSA()
self._bbs = BBS()
self._elg = ElGamal()
def get_rsa_public_key(self):
return self._rsa.get_public_key()
def get_rsa_encrypted_seed(self, person):
return self._bbs.n, RSA().encrypt(self._bbs._seed, *person.get_rsa_public_key())
def get_elg_public_key(self):
return self._elg.get_public_key()
def get_elg_encrypted_seed(self, person):
c, d = ElGamal().encrypt(self._bbs._seed, *person.get_elg_public_key())
return self._bbs.n, c, d
def send_encrypted_message(self):
nonce = int(time.time()*1000)
encrypted_message = self._bbs.encrypt(self._secret_message, nonce)
return encrypted_message, nonce
def receive_encrypted_message_rsa(self, person):
bbs_n, bbs_encrypted_seed = person.get_rsa_encrypted_seed(self)
bbs_seed = self._rsa.decrypt(bbs_encrypted_seed)
person_bbs = BBS(bbs_n, bbs_seed)
encrypted_message, nonce = person.send_encrypted_message()
message = person_bbs.decrypt(encrypted_message, nonce)
print('{} received the secret message "{}" from {}'.format(self.name, message, person.name))
def receive_encrypted_message_elg(self, person):
bbs_n, bbs_encrypted_seed, d = person.get_elg_encrypted_seed(self)
bbs_seed = self._elg.decrypt(bbs_encrypted_seed, d)
person_bbs = BBS(bbs_n, bbs_seed)
encrypted_message, nonce = person.send_encrypted_message()
message = person_bbs.decrypt(encrypted_message, nonce)
print('{} received the secret message "{}" from {}'.format(self.name, message, person.name))
This section demonstrates the homomorphic capabilities of RSA as a didactic exercise. As mentioned in the companion paper, RSA does not have semantic security, so it cannot be used securely for homomorphic processing in practical applications.
rsa = RSA(next_probable_prime(100000000000000000), next_probable_prime(1100000000000000000), 17)
cipher = rsa.encrypt(text_to_int("Data Science"))
print("Decrypting message with private key: {}".format(int_to_text(rsa.decrypt(cipher))))
print("\nAttacking RSA by factoring public key via quadratic sieve:")
n, e = rsa.get_public_key()
factors = MPQS(n).factor()
p, q = factors[0], factors[1]
d = gcd_ext(e, (p-1)*(q-1))[1] % ((p-1)*(q-1))
print(int_to_text(RSA().decrypt(cipher, n, d)))
print("\nHomomorphic RSA encryption example: multiplication <-> multiplication")
a = random.randint(0,1000000)
b = random.randint(0,1000000)
print(rsa.encrypt(a % rsa.n)*rsa.encrypt(b % rsa.n) % rsa.n)
print(rsa.encrypt(a*b % rsa.n) % rsa.n)
Decrypting message with private key: Data Science

Attacking RSA by factoring public key via quadratic sieve:
Trying to factor 36-digit number 110000000000000009600000000000000189
MPQS trying to factor 36-digit number 110000000000000009600000000000000189
Number of b-smooth numbers found (749 needed at most): 707
[1100000000000000063, 100000000000000003]
Data Science

Homomorphic RSA encryption example: multiplication <-> multiplication
82220119621987860348718849059267407
82220119621987860348718849059267407
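The multiplicative homomorphism shown by the two matching ciphertexts above can also be checked end to end with tiny textbook parameters (toy values, not the notebook's keys), including decryption of the ciphertext product:

```python
p, q, e = 1009, 1013, 17
n = p * q
d = pow(e, -1, (p - 1) * (q - 1))
a, b = 1234, 5678
c_prod = pow(a, e, n) * pow(b, e, n) % n   # multiply the two ciphertexts
assert c_prod == pow(a * b % n, e, n)      # equals the encryption of a*b
assert pow(c_prod, d, n) == a * b % n      # and decrypts to a*b (mod n)
```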
alice = Person("Alice", "Alice's secret message")
bob = Person("Bob", "Bob's secret message")
bob.receive_encrypted_message_rsa(alice)
bob.receive_encrypted_message_elg(alice)
Bob received the secret message "Alice's secret message" from Alice Bob received the secret message "Alice's secret message" from Alice
cryptolib/hefloat

This notebook exemplifies a fully encrypted workflow where the client has a database of individual data records and wants to apply a mortality model (logistic regression) to predict the mortality risk for the individuals represented in each data record.
For this, we use a light version of Tune Insight's public cryptolib/hefloat library, which provides a simplified interface to abstract cryptographic concepts and enable quick and simple manipulation of encrypted records. The steps are as follows:
In order to run the following cells, you need to have the Tune Insight hefloat package from tuneinsight-hefloat-0.4.2.tar.gz installed. Run the following cell to install it in your environment.
If you have any questions or find any issues with the library, do not hesitate to contact us at contact@tuneinsight.com.
#Install the required Tune Insight package
%pip install tuneinsight-hefloat-0.4.2.tar.gz
The hefloat library works with approximate arithmetic and uses a cryptosystem adapted to machine learning operations (CKKS).
First, we select the required precision for the homomorphic computations (equivalent to the input scale $\Delta$) and the depth of the circuit (maximum number of consecutive products before the results are decrypted). The library automatically selects secure parameters as a function of the required circuit depth and precision.
from tuneinsight.cryptolib.hefloat import hefloat
# Parameterization: scale/precision and circuit depth
log_scale = 45 # Fixed-point arithmetic floating point scaling factor in bits (log2(Delta))
levels = 7 # Circuit depth
log_qi = [log_scale+5] + levels*[log_scale] # 5 additional bits for the lowest level, to account for plaintext growth
log_pi = [log_scale+5] # Auxiliary module used for relinearization (usually, at least of the same size as the lowest level q0)
# In order to generate an instance of the cryptosystem, the RLWE ring degree is automatically chosen to ensure at least 128 bits of security
# A context stores the scheme cryptographic parameters and a key generator
context = hefloat.new_context(log_qi = log_qi, log_default_scale= log_scale, log_pi = log_pi)
#Print some information about the cryptographic parameters
print(f'Log2 N: {context.parameters.log_n()}')
print(f'Log2 Moduli Chain: Q{log_qi} + P{log_pi}')
print(f'Log2 QP: {context.parameters.log_q() + context.parameters.log_p()}')
print(f'Log2 Slots: {context.parameters.log_slots()}')
print(f'Available Depth: {levels}')
The owner of the data generates a public-private key pair.
The secret key must remain with the data owner, as it is the key that can decrypt any ciphertext.
The evaluator object holds the public keys (including the relinearization key) needed to execute homomorphic circuits on encrypted data, but the evaluator alone cannot perform decryption operations, as the secret key is not stored in it.
Therefore, the client can send the evaluator and its public keys to the server where the encrypted computation must be carried out, with the guarantee that the server will not be able to access any clear-text data.
# Generate a fresh secret key
sk = context.new_secret_key()
# Instantiate an evaluator with a relinearization key
# The relinearization key is a public-evaluation key required to ensure ciphertext x ciphertext compactness
# The resulting evaluator object contains only public information and can be freely shared
evaluator = context.new_evaluator(context.new_relinearization_key(sk))
The cryptosystem naturally supports homomorphic polynomial operations, but non-polynomial operations have to be approximated.
A Chebyshev polynomial approximation is numerically stable and achieves a low approximation error (close to that of the minimax polynomial) across the defined input interval. In this case (logistic regression), the function to approximate is the Sigmoid. The approximation interval is chosen as a function of the domain spanned by the inputs to the Sigmoid function (after the scalar product between the input vector and the regression coefficient vector). The degree of the approximation polynomial ($2^6-1=63$ in our case) can be chosen to fit the available depth with which the cryptosystem has been parameterized.
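Before running anything encrypted, the quality of such an approximation can be checked in plain NumPy. The following sketch (no encryption involved) interpolates the Sigmoid with a degree-63 Chebyshev polynomial on $[-12, 12]$, the same parameters used below, and measures the maximum error on a fine grid:

```python
import numpy as np
import numpy.polynomial.chebyshev as chebyshev

a, b = -12, 12   # approximation interval
degree = 63      # degree of the Chebyshev interpolation polynomial

# Interpolate the Sigmoid on [a, b], mapped to the Chebyshev domain [-1, 1]
coeffs = chebyshev.chebinterpolate(
    lambda t: 1 / (1 + np.exp(-((b - a) / 2 * t + (b + a) / 2))), degree)

# Evaluate the approximation on a fine grid and compare with the exact Sigmoid
x = np.linspace(a, b, 2001)
t = (2 * x - (a + b)) / (b - a)   # affine map [a, b] -> [-1, 1]
approx = chebyshev.chebval(t, coeffs)
exact = 1 / (1 + np.exp(-x))
err = np.max(np.abs(approx - exact))
print(f"max abs error of degree-{degree} approximation: {err:.2e}")
```

The error stays far below the precision targeted by the cryptosystem, so the polynomial substitution is not the accuracy bottleneck.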
For this example, we generate a table of uniformly distributed random data that represents $l=2^{13}=8192$ records with $k=200$ features. The model coefficients (regression coefficients $\beta_i$, including the bias or intercept coefficient $\beta_0$) are also randomly chosen. In a real scenario, these coefficients would be obtained from the model provider or from a prior training process. The latter can be run on local data or on third-party or distributed data, in which case it can also leverage encrypted computation (see the companion notebook using the full Tune Insight SDK for an example of encrypted distributed training with dataset $D_2$).
import numpy.polynomial.chebyshev as chebyshev
import numpy as np
# Expected interval of the encrypted values after the scalar product
a = -12
b = 12
# Interpolates the Sigmoid in the interval [-12, 12] and returns the coefficients
# for the Chebyshev approximation polynomial in the Chebyshev basis
coeffs = chebyshev.chebinterpolate(lambda x: 1/(1+np.exp(-((b-a)/2 * x + (b+a)/2))), 63)
## Synthetic data generation:
# Number of samples to process in parallel (available plaintext slots that one encryption can hold)
batch_size = context.slots()
# Number of features (k=200)
features = 200
# Generate random data in [-0.5, 0.5]. This is the matrix A'
data = np.random.rand(batch_size, features)-0.5
# Generate random regression weights in [0, 1]. These represent beta_i, i=1,...,k
weights = np.random.rand(features, 1)
# Generate random bias (intercept coefficient) in [0, 1]. This represents beta_0
bias = np.random.rand(1)
In order to leverage the inherent parallelization of the cryptosystem (SIMD - Single Instruction Multiple Data) offered by the underlying polynomial/vector arithmetic, we encrypt input values packed into a polynomial/vector representation.
In this case, we exemplify vertical packing (see Figure 11 of the tutorial paper), where each encryption contains a vector of values from multiple data records (one column of the data matrix $A'$). The input values are encoded in the slots (batched), in order to enable component-wise homomorphic operations.
# This optional parameter defines whether the input vectors will be encoded in the coefficients domain (if batched=False)
# or in the slots domain (if batched=True). The latter is the default behavior, and it enables component-wise homomorphic operations
# (additions and products)
batched = True
# The encrypt function can receive a two-dimensional matrix as input, in which case it encrypts each row of the input matrix in one ciphertext.
# Therefore, we transpose the input A', in order to encrypt each column of A' in one ciphertext.
# We need to explicitly make a copy to ensure a correct memory
# alignment when passing C pointers of arrays to the Go wrapper.
# The function returns an object that stores a vector of ciphertexts.
encrypted_data = context.encrypt(data.transpose().copy(), sk, batched)
# As for the regression coefficients, we encrypt each of the weights replicated in all slots of the corresponding ciphertext.
# For this, we apply repetition coding (with tile) and pass the resulting matrix as input to the encrypt function, so that each row is encrypted in a separate ciphertext.
# The result is an object that stores a vector of ciphertexts, each containing one regression coefficient replicated in all its slots.
encrypted_weights = context.encrypt(np.tile(weights, (1, batch_size))* 2/(b-a), sk, batched)
# The intercept coefficient or bias is also encrypted in its own ciphertext, with the same repetition coding as all the other regression coefficients
encrypted_bias = context.encrypt(np.tile(bias, (1, batch_size))* 2/(b-a) + (-a-b)/(b-a), sk, batched)
The prediction for the logistic regression involves, for each data record, the computation of
$y_i=\mu(\beta_0+\sum_{j=1}^k a_{i,j}\beta_j)$.
This can be broken down into three steps: the scalar product $\sum_{j=1}^k a_{i,j}\beta_j$, the addition of the intercept $\beta_0$, and the evaluation of the Sigmoid $\mu(\cdot)$.
Thanks to the used vertical packing, these three subsequent operations can be executed in parallel over all the data records as if we were manipulating a single record (SIMD).
The scalar product is evaluated using homomorphic products and additions. The scalar_product function takes two vectors of ciphertexts and computes the scalar product between them, which is the homomorphic equivalent of performing $l=8192$ parallel scalar products between the $l$ data vectors of length $200$ and the regression coefficient vector. The addition of the intercept is likewise a parallel (slot-wise) operation.
# Encrypted evaluation of data @ weights, computed as np.sum(data.transpose() * np.tile(weights, (1, batch_size)), axis=0)
# This is faster than, but equivalent to, doing evaluator.sum(evaluator.mul(encrypted_data, encrypted_weights), axis=0)
encrypted_scalar_product = evaluator.scalar_product(encrypted_data, encrypted_weights)
# Encrypted evaluation of data @ weights + bias
encrypted_scalar_product_plus_bias = evaluator.add(encrypted_bias, encrypted_scalar_product)
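In plaintext NumPy terms, the two homomorphic calls above correspond to the following computation on the vertically packed data (toy sizes, no encryption):

```python
import numpy as np

rng = np.random.default_rng(1)
batch_size, features = 8, 5          # toy sizes instead of 8192 x 200
data = rng.random((batch_size, features)) - 0.5
weights = rng.random((features, 1))
bias = rng.random(1)

# Vertical packing: each row of data.T plays the role of one ciphertext,
# holding one feature across all records; each weight is replicated
# across all slots of its own "ciphertext"
packed_data = data.T                               # (features, batch_size)
packed_weights = np.tile(weights, (1, batch_size))

# Slot-wise products followed by a sum over the ciphertexts = batch_size
# parallel scalar products, then a slot-wise bias addition
simd = np.sum(packed_data * packed_weights, axis=0) + bias
reference = (data @ weights + bias).ravel()
print(np.allclose(simd, reference))
```

One pass over the packed columns thus prices all records at once, which is exactly the SIMD gain the packing is designed for.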
The Sigmoid is evaluated through the approximation polynomial calculated above. The function polynomial takes the coefficients of the Chebyshev decomposition of a given polynomial and evaluates this polynomial on the input encryption(s). This is also a SIMD operation, where the polynomial evaluation is homomorphically applied component-wise to all the slots in the ciphertext, so the result is indeed the vector of predictions for all the input data records, one prediction in each of the slots of the resulting ciphertext.
# Encrypted evaluation of sigmoid(data @ weights + bias)
encrypted_prediction = evaluator.polynomial(encrypted_scalar_product_plus_bias, coeffs=coeffs, basis="Chebyshev")
The decryption requires the secret key to succeed, and produces as a result the vector of $l=8192$ predictions.
# Decrypt the values
prediction = context.decrypt(encrypted_prediction, sk)[:, :batch_size]
In order to measure the precision preserved by the whole encrypted process, we evaluate the same logistic regression on the cleartext data and compute the negative base-2 logarithm of the error, which gives us the preserved bits of precision (roughly 32 bits in this case). The scale of the inputs can be modified when parameterizing the cryptosystem in the first step, in order to adapt the precision to the required target.
from math import log
# Finally, we evaluate the plaintext circuit
clear_target = 1/(np.exp(-(data @ weights + bias))+1)
# And compare with the decrypted result
print(f'Obtained: {prediction}')
print(f'Clear_tg: {clear_target.transpose()}')
print(f'Precision as -log2(avg_l2(obtained-clear_tg))): {-log(np.sqrt(np.sum((prediction-clear_target.transpose())**2))/batch_size, 2)}')
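As a small standalone illustration of this precision metric, the following sketch (plain NumPy, with hypothetical values carrying roughly $10^{-8}$ of noise) applies the same formula to a toy prediction vector:

```python
import numpy as np
from math import log

# Hypothetical decrypted predictions vs. cleartext targets, noise ~1e-8
obtained = np.array([0.50000001, 0.29999998, 0.80000002, 0.10000001])
target = np.array([0.5, 0.3, 0.8, 0.1])

# Same metric as above: -log2 of the l2 error norm divided by the batch size
bits = -log(np.sqrt(np.sum((obtained - target) ** 2)) / len(target), 2)
print(f"preserved bits of precision: {bits:.1f}")
```

An error on the order of $10^{-8}$ translates into roughly 27 preserved bits here; increasing log_scale when parameterizing the cryptosystem raises this figure at the cost of larger ciphertext moduli.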
For more information on the encrypted statistical and machine learning capabilities of the full secure federated platform, do not hesitate to contact us at contact@tuneinsight.com.